From b870493fed72c1fe5582ce383f81887f08be9f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 10:26:46 +0100 Subject: [PATCH 01/47] server: create Customer model and start to link logic to it --- ...09-27-1727_add_checkout_computer_tax_id.py | 4 +- server/polar/benefit/benefits/ads.py | 6 +- server/polar/benefit/benefits/base.py | 22 +-- server/polar/benefit/benefits/custom.py | 6 +- server/polar/benefit/benefits/discord.py | 14 +- .../polar/benefit/benefits/downloadables.py | 10 +- .../benefit/benefits/github_repository.py | 14 +- server/polar/benefit/benefits/license_keys.py | 12 +- server/polar/benefit/service/benefit_grant.py | 133 ++++++++---------- server/polar/checkout/service.py | 74 +++++----- server/polar/customer/__init__.py | 0 server/polar/customer/service.py | 51 +++++++ server/polar/eventstream/service.py | 6 + server/polar/{checkout => kit}/tax.py | 0 server/polar/license_key/schemas.py | 6 +- server/polar/license_key/service.py | 33 +++-- server/polar/models/__init__.py | 2 + server/polar/models/benefit_grant.py | 15 +- server/polar/models/checkout.py | 10 +- server/polar/models/customer.py | 39 +++++ server/polar/models/downloadable.py | 12 +- server/polar/models/license_key.py | 10 +- server/polar/models/order.py | 10 +- server/polar/models/subscription.py | 10 +- server/polar/models/transaction.py | 12 +- server/polar/order/service.py | 74 ++++------ server/polar/order/sorting.py | 2 +- server/polar/storefront/schemas.py | 5 +- server/polar/storefront/service.py | 36 ++--- server/polar/subscription/service.py | 109 +++++--------- server/polar/user/schemas/downloadables.py | 4 +- server/polar/user/service/downloadables.py | 18 +-- server/tests/checkout/test_endpoints.py | 2 +- server/tests/checkout/test_service.py | 2 +- server/tests/checkout/test_tax.py | 2 +- 35 files changed, 385 insertions(+), 380 deletions(-) create mode 100644 server/polar/customer/__init__.py create mode 100644 server/polar/customer/service.py rename server/polar/{checkout => kit}/tax.py (100%) create mode 100644 server/polar/models/customer.py diff --git a/server/migrations/versions/2024-09-27-1727_add_checkout_computer_tax_id.py b/server/migrations/versions/2024-09-27-1727_add_checkout_computer_tax_id.py index 6d9346ee12..5d4bfcefde 100644 --- a/server/migrations/versions/2024-09-27-1727_add_checkout_computer_tax_id.py +++ b/server/migrations/versions/2024-09-27-1727_add_checkout_computer_tax_id.py @@ -10,7 +10,7 @@ from alembic import op # Polar Custom Imports -import polar.checkout.tax +import polar.kit.tax # revision identifiers, used by Alembic. revision = "e4473617a8e9" @@ -25,7 +25,7 @@ def upgrade() -> None: "checkouts", sa.Column( "customer_tax_id", - polar.checkout.tax.TaxIDType(astext_type=sa.Text()), + polar.kit.tax.TaxIDType(astext_type=sa.Text()), nullable=True, ), ) diff --git a/server/polar/benefit/benefits/ads.py b/server/polar/benefit/benefits/ads.py index f8dffcf415..53c57a45ca 100644 --- a/server/polar/benefit/benefits/ads.py +++ b/server/polar/benefit/benefits/ads.py @@ -1,7 +1,7 @@ from typing import Any, cast from polar.auth.models import AuthSubject -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import BenefitAds, BenefitAdsProperties from polar.models.benefit_grant import BenefitGrantAdsProperties @@ -14,7 +14,7 @@ class BenefitAdsService( async def grant( self, benefit: BenefitAds, - user: User, + customer: Customer, grant_properties: BenefitGrantAdsProperties, *, update: bool = False, @@ -28,7 +28,7 @@ async def grant( async def revoke( self, benefit: BenefitAds, - user: User, + customer: Customer, grant_properties: BenefitGrantAdsProperties, *, attempt: int = 1, diff --git a/server/polar/benefit/benefits/base.py b/server/polar/benefit/benefits/base.py index a97c21bb29..f599e390d6 100644 --- a/server/polar/benefit/benefits/base.py +++ b/server/polar/benefit/benefits/base.py @@ -2,7 +2,7 @@ from polar.auth.models import AuthSubject from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError -from polar.models import Benefit, Organization, User +from polar.models import Benefit, Customer, Organization, User from polar.models.benefit import BenefitProperties from polar.notifications.notification import ( BenefitPreconditionErrorNotificationContextualPayload, @@ -93,19 +93,19 @@ def __init__(self, session: AsyncSession, redis: Redis) -> None: async def grant( self, benefit: B, - user: User, + customer: Customer, grant_properties: BGP, *, update: bool = False, attempt: int = 1, ) -> BGP: """ - Executes the logic to grant a benefit to a backer. + Executes the logic to grant a benefit to a customer. Args: benefit: The Benefit to grant. - user: The backer user. - grant_properties: Stored properties for this specific benefit and user. + customer: The customer. + grant_properties: Stored properties for this specific benefit and customer. Might be available at this stage if we're updating an already granted benefit. update: Whether we are updating an already granted benefit. @@ -113,7 +113,7 @@ async def grant( Useful for the worker to implement retry logic. Returns: - A dictionary with data to store for this specific benefit and user. + A dictionary with data to store for this specific benefit and customer. For example, it can be useful to store external identifiers that may help when updating the grant or revoking it. **Existing properties will be overriden, so be sure to include all the data @@ -129,23 +129,23 @@ async def grant( async def revoke( self, benefit: B, - user: User, + customer: Customer, grant_properties: BGP, *, attempt: int = 1, ) -> BGP: """ - Executes the logic to revoke a benefit from a backer. + Executes the logic to revoke a benefit from a customer. Args: benefit: The Benefit to revoke. - user: The backer user. - grant_properties: Stored properties for this specific benefit and user. + customer: The customer. + grant_properties: Stored properties for this specific benefit and customer. attempt: Number of times we attempted to revoke the benefit. Useful for the worker to implement retry logic. Returns: - A dictionary with data to store for this specific benefit and user. + A dictionary with data to store for this specific benefit and customer. For example, it can be useful to store external identifiers that may help when updating the grant or revoking it. **Existing properties will be overriden, so be sure to include all the data diff --git a/server/polar/benefit/benefits/custom.py b/server/polar/benefit/benefits/custom.py index a224b058e6..70e8b09db8 100644 --- a/server/polar/benefit/benefits/custom.py +++ b/server/polar/benefit/benefits/custom.py @@ -1,7 +1,7 @@ from typing import Any, cast from polar.auth.models import AuthSubject -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import BenefitCustom, BenefitCustomProperties from polar.models.benefit_grant import BenefitGrantCustomProperties @@ -16,7 +16,7 @@ class BenefitCustomService( async def grant( self, benefit: BenefitCustom, - user: User, + customer: Customer, grant_properties: BenefitGrantCustomProperties, *, update: bool = False, @@ -27,7 +27,7 @@ async def grant( async def revoke( self, benefit: BenefitCustom, - user: User, + customer: Customer, grant_properties: BenefitGrantCustomProperties, *, attempt: int = 1, diff --git a/server/polar/benefit/benefits/discord.py b/server/polar/benefit/benefits/discord.py index 6af3b24461..aa01ac96a4 100644 --- a/server/polar/benefit/benefits/discord.py +++ b/server/polar/benefit/benefits/discord.py @@ -9,7 +9,7 @@ from polar.integrations.discord.service import discord_bot as discord_bot_service from polar.integrations.discord.service import discord_user as discord_user_service from polar.logging import Logger -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import BenefitDiscord, BenefitDiscordProperties from polar.models.benefit_grant import BenefitGrantDiscordProperties from polar.notifications.notification import ( @@ -70,7 +70,7 @@ class BenefitDiscordService( async def grant( self, benefit: BenefitDiscord, - user: User, + customer: Customer, grant_properties: BenefitGrantDiscordProperties, *, update: bool = False, @@ -78,7 +78,7 @@ async def grant( ) -> BenefitGrantDiscordProperties: bound_logger = log.bind( benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), ) bound_logger.debug("Grant benefit") @@ -94,7 +94,9 @@ async def grant( bound_logger.debug( "Revoke before granting because guild or role have changed" ) - await self.revoke(benefit, user, grant_properties, attempt=attempt) + await self.revoke(benefit, customer, grant_properties, attempt=attempt) + + # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User try: account = await discord_user_service.get_oauth_account(self.session, user) @@ -133,14 +135,14 @@ async def grant( async def revoke( self, benefit: BenefitDiscord, - user: User, + customer: Customer, grant_properties: BenefitGrantDiscordProperties, *, attempt: int = 1, ) -> BenefitGrantDiscordProperties: bound_logger = log.bind( benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), ) guild_id = grant_properties.get("guild_id") diff --git a/server/polar/benefit/benefits/downloadables.py b/server/polar/benefit/benefits/downloadables.py index 6ec9d74f68..677b3f996c 100644 --- a/server/polar/benefit/benefits/downloadables.py +++ b/server/polar/benefit/benefits/downloadables.py @@ -8,7 +8,7 @@ from polar.auth.models import AuthSubject from polar.benefit import schemas as benefit_schemas from polar.logging import Logger -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import BenefitDownloadables, BenefitDownloadablesProperties from polar.models.benefit_grant import BenefitGrantDownloadablesProperties from polar.user.service.downloadables import downloadable as downloadable_service @@ -35,7 +35,7 @@ class BenefitDownloadablesService( async def grant( self, benefit: BenefitDownloadables, - user: User, + customer: Customer, grant_properties: BenefitGrantDownloadablesProperties, *, update: bool = False, @@ -49,7 +49,7 @@ async def grant( for file_id in file_ids: downloadable = await downloadable_service.grant_for_benefit_file( self.session, - user=user, + customer=customer, benefit_id=benefit.id, file_id=file_id, ) @@ -63,14 +63,14 @@ async def grant( async def revoke( self, benefit: BenefitDownloadables, - user: User, + customer: Customer, grant_properties: BenefitGrantDownloadablesProperties, *, attempt: int = 1, ) -> BenefitGrantDownloadablesProperties: await downloadable_service.revoke_for_benefit( self.session, - user=user, + customer=customer, benefit_id=benefit.id, ) return {} diff --git a/server/polar/benefit/benefits/github_repository.py b/server/polar/benefit/benefits/github_repository.py index 660cbaf712..a9df4e078e 100644 --- a/server/polar/benefit/benefits/github_repository.py +++ b/server/polar/benefit/benefits/github_repository.py @@ -13,7 +13,7 @@ github_repository_benefit_user_service, ) from polar.logging import Logger -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import ( BenefitGitHubRepository, BenefitGitHubRepositoryProperties, @@ -80,7 +80,7 @@ class BenefitGitHubRepositoryService( async def grant( self, benefit: BenefitGitHubRepository, - user: User, + customer: Customer, grant_properties: BenefitGrantGitHubRepositoryProperties, *, update: bool = False, @@ -88,7 +88,7 @@ async def grant( ) -> BenefitGrantGitHubRepositoryProperties: bound_logger = log.bind( benefit_id=str(benefit.id), - user_id=str(user.id), + user_id=str(customer.id), ) bound_logger.debug("Grant benefit") @@ -98,6 +98,8 @@ async def grant( repository_name = benefit.properties["repository_name"] permission = benefit.properties["permission"] + # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User + # When inviting users: Use the users identity from the "main" Polar GitHub App oauth_account = user.get_oauth_account(OAuthPlatform.github) if oauth_account is None or oauth_account.account_username is None: @@ -159,14 +161,14 @@ async def grant( async def revoke( self, benefit: BenefitGitHubRepository, - user: User, + customer: Customer, grant_properties: BenefitGrantGitHubRepositoryProperties, *, attempt: int = 1, ) -> BenefitGrantGitHubRepositoryProperties: bound_logger = log.bind( benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), ) if benefit.properties["repository_id"]: @@ -178,6 +180,8 @@ async def revoke( repository_owner = benefit.properties["repository_owner"] repository_name = benefit.properties["repository_name"] + # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User + oauth_account = user.get_oauth_account(OAuthPlatform.github) if oauth_account is None or oauth_account.account_username is None: raise diff --git a/server/polar/benefit/benefits/license_keys.py b/server/polar/benefit/benefits/license_keys.py index 0d8a60966e..245e434ca1 100644 --- a/server/polar/benefit/benefits/license_keys.py +++ b/server/polar/benefit/benefits/license_keys.py @@ -8,7 +8,7 @@ from polar.auth.models import AuthSubject from polar.license_key.service import license_key as license_key_service from polar.logging import Logger -from polar.models import Organization, User +from polar.models import Customer, Organization, User from polar.models.benefit import BenefitLicenseKeys, BenefitLicenseKeysProperties from polar.models.benefit_grant import BenefitGrantLicenseKeysProperties @@ -31,7 +31,7 @@ class BenefitLicenseKeysService( async def grant( self, benefit: BenefitLicenseKeys, - user: User, + customer: Customer, grant_properties: BenefitGrantLicenseKeysProperties, *, update: bool = False, @@ -43,7 +43,7 @@ async def grant( key = await license_key_service.user_grant( self.session, - user=user, + customer=customer, benefit=benefit, license_key_id=current_lk_id, ) @@ -55,7 +55,7 @@ async def grant( async def revoke( self, benefit: BenefitLicenseKeys, - user: User, + customer: Customer, grant_properties: BenefitGrantLicenseKeysProperties, *, attempt: int = 1, @@ -64,7 +64,7 @@ async def revoke( if not license_key_id: log.info( "license_key.revoke.skip", - user_id=user.id, + customer_id=customer.id, benefit_id=benefit.id, message="No license key to revoke", ) @@ -72,7 +72,7 @@ async def revoke( await license_key_service.user_revoke( self.session, - user=user, + customer=customer, benefit=benefit, license_key_id=UUID(license_key_id), ) diff --git a/server/polar/benefit/service/benefit_grant.py b/server/polar/benefit/service/benefit_grant.py index 1ff4cdd09d..d2577df3d3 100644 --- a/server/polar/benefit/service/benefit_grant.py +++ b/server/polar/benefit/service/benefit_grant.py @@ -8,22 +8,15 @@ from polar.benefit.benefits import BenefitPreconditionError, get_benefit_service from polar.benefit.schemas import BenefitGrantWebhook +from polar.customer.service import customer as customer_service from polar.eventstream.service import publish as eventstream_publish from polar.exceptions import PolarError from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.logging import Logger -from polar.models import ( - Benefit, - BenefitGrant, - OAuthAccount, - Product, - ProductBenefit, - User, -) +from polar.models import Benefit, BenefitGrant, Customer, Product, ProductBenefit from polar.models.benefit import BenefitProperties, BenefitType from polar.models.benefit_grant import BenefitGrantPropertiesBase, BenefitGrantScope -from polar.models.user import OAuthPlatform from polar.models.webhook_endpoint import WebhookEventType from polar.notifications.notification import ( BenefitPreconditionErrorNotificationPayload, @@ -34,11 +27,8 @@ from polar.organization.service import organization as organization_service from polar.postgres import AsyncSession, sql from polar.redis import Redis -from polar.user.service.user import user as user_service from polar.webhook.service import webhook as webhook_service -from polar.webhook.webhooks import ( - WebhookPayloadTypeAdapter, -) +from polar.webhook.webhooks import WebhookPayloadTypeAdapter from polar.worker import enqueue_job from .benefit_grant_scope import resolve_scope, scope_to_args @@ -116,8 +106,7 @@ async def list( benefit: Benefit, *, is_granted: bool | None = None, - user_id: UUID | None = None, - github_user_id: int | None = None, + customer_id: Sequence[UUID] | None = None, pagination: PaginationParams, ) -> tuple[Sequence[BenefitGrant], int]: statement = ( @@ -132,18 +121,8 @@ async def list( if is_granted is not None: statement = statement.where(BenefitGrant.is_granted.is_(is_granted)) - if user_id is not None: - statement = statement.where(BenefitGrant.user_id == user_id) - - if github_user_id is not None: - oauth_account_statement = select(OAuthAccount.user_id).where( - OAuthAccount.deleted_at.is_(None), - OAuthAccount.platform == OAuthPlatform.github, - OAuthAccount.account_id == str(github_user_id), - ) - statement = statement.where( - BenefitGrant.user_id.in_(oauth_account_statement) - ) + if customer_id is not None: + statement = statement.where(BenefitGrant.customer_id.in_(customer_id)) return await paginate(session, statement, pagination=pagination) @@ -151,18 +130,22 @@ async def grant_benefit( self, session: AsyncSession, redis: Redis, - user: User, + customer: Customer, benefit: Benefit, *, attempt: int = 1, **scope: Unpack[BenefitGrantScope], ) -> BenefitGrant: - log.info("Granting benefit", benefit_id=str(benefit.id), user_id=str(user.id)) + log.info( + "Granting benefit", benefit_id=str(benefit.id), customer_id=str(customer.id) + ) - grant = await self.get_by_benefit_and_scope(session, user, benefit, **scope) + grant = await self.get_by_benefit_and_scope(session, customer, benefit, **scope) if grant is None: - grant = BenefitGrant(user=user, benefit=benefit, properties={}, **scope) + grant = BenefitGrant( + customer=customer, benefit=benefit, properties={}, **scope + ) session.add(grant) elif grant.is_granted: return grant @@ -172,12 +155,12 @@ async def grant_benefit( try: properties = await benefit_service.grant( benefit, - user, + customer, grant.properties, attempt=attempt, ) except BenefitPreconditionError as e: - await self.handle_precondition_error(session, e, user, benefit, **scope) + await self.handle_precondition_error(session, e, customer, benefit, **scope) grant.granted_at = None else: grant.properties = properties @@ -189,13 +172,13 @@ async def grant_benefit( await eventstream_publish( "benefit.granted", {"benefit_id": benefit.id, "benefit_type": benefit.type}, - user_id=user.id, + customer_id=customer.id, ) log.info( "Benefit granted", benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), grant_id=str(grant.id), ) @@ -212,18 +195,22 @@ async def revoke_benefit( self, session: AsyncSession, redis: Redis, - user: User, + customer: Customer, benefit: Benefit, *, attempt: int = 1, **scope: Unpack[BenefitGrantScope], ) -> BenefitGrant: - log.info("Revoking benefit", benefit_id=str(benefit.id), user_id=str(user.id)) + log.info( + "Revoking benefit", benefit_id=str(benefit.id), customer_id=str(customer.id) + ) - grant = await self.get_by_benefit_and_scope(session, user, benefit, **scope) + grant = await self.get_by_benefit_and_scope(session, customer, benefit, **scope) if grant is None: - grant = BenefitGrant(user=user, benefit=benefit, properties={}, **scope) + grant = BenefitGrant( + customer=customer, benefit=benefit, properties={}, **scope + ) session.add(grant) elif grant.is_revoked: return grant @@ -235,13 +222,13 @@ async def revoke_benefit( # * If the service requires grants to be revoked individually # * If there is only one grant remaining for this benefit, # so the benefit remains if other grants exist via other purchases - other_grants = await self._get_granted_by_benefit_and_user( - session, benefit, user + other_grants = await self._get_granted_by_benefit_and_customer( + session, benefit, customer ) if benefit_service.should_revoke_individually or len(other_grants) < 2: properties = await benefit_service.revoke( benefit, - user, + customer, grant.properties, attempt=attempt, ) @@ -255,13 +242,13 @@ async def revoke_benefit( await eventstream_publish( "benefit.revoked", {"benefit_id": benefit.id, "benefit_type": benefit.type}, - user_id=user.id, + customer_id=customer.id, ) log.info( "Benefit revoked", benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), grant_id=str(grant.id), ) @@ -278,7 +265,7 @@ async def enqueue_benefits_grants( self, session: AsyncSession, task: Literal["grant", "revoke"], - user: User, + customer: Customer, product: Product, **scope: Unpack[BenefitGrantScope], ) -> None: @@ -289,7 +276,7 @@ async def enqueue_benefits_grants( for benefit in product.benefits: enqueue_job( f"benefit.{task}", - user_id=user.id, + customer_id=customer.id, benefit_id=benefit.id, **scope_to_args(scope), ) @@ -297,7 +284,7 @@ async def enqueue_benefits_grants( for outdated_grant in outdated_grants: enqueue_job( "benefit.revoke", - user_id=user.id, + customer_id=customer.id, benefit_id=outdated_grant.benefit_id, **scope_to_args(scope), ) @@ -331,22 +318,22 @@ async def update_benefit_grant( benefit = grant.benefit - user = await user_service.get(session, grant.user_id) - assert user is not None + customer = await customer_service.get(session, grant.customer_id) + assert customer is not None previous_properties = grant.properties benefit_service = get_benefit_service(benefit.type, session, redis) try: properties = await benefit_service.grant( benefit, - user, + customer, grant.properties, update=True, attempt=attempt, ) except BenefitPreconditionError as e: scope = await resolve_scope(session, grant.scope) - await self.handle_precondition_error(session, e, user, benefit, **scope) + await self.handle_precondition_error(session, e, customer, benefit, **scope) grant.granted_at = None else: grant.properties = properties @@ -382,17 +369,17 @@ async def delete_benefit_grant( if grant.is_revoked: return grant - await session.refresh(grant, {"subscription", "benefit"}) + await session.refresh(grant, {"benefit"}) benefit = grant.benefit - user = await user_service.get(session, grant.user_id) - assert user is not None + customer = await customer_service.get(session, grant.customer_id) + assert customer is not None previous_properties = grant.properties benefit_service = get_benefit_service(benefit.type, session, redis) properties = await benefit_service.revoke( benefit, - user, + customer, grant.properties, attempt=attempt, ) @@ -414,24 +401,24 @@ async def handle_precondition_error( self, session: AsyncSession, error: BenefitPreconditionError, - user: User, + customer: Customer, benefit: Benefit, **scope: Unpack[BenefitGrantScope], ) -> None: if error.payload is None: log.warning( - "A precondition error was raised but the user was not notified. " + "A precondition error was raised but the customer was not notified. " "We probably should implement a notification for this error.", benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), scope=scope, ) return log.info( - "Precondition error while granting benefit. User was informed.", + "Precondition error while granting benefit. Customer was informed.", benefit_id=str(benefit.id), - user_id=str(user.id), + customer_id=str(customer.id), ) # Disable the notification for now as it's a bit noisy for some use-cases @@ -470,21 +457,23 @@ async def handle_precondition_error( async def enqueue_grants_after_precondition_fulfilled( self, session: AsyncSession, - user: User, + customer: Customer, benefit_type: BenefitType, ) -> None: log.info( "Enqueueing benefit grants after precondition fulfilled", - user_id=str(user.id), + customer_id=str(customer.id), benefit_type=benefit_type, ) - grants = await self._get_by_user_and_benefit_type(session, user, benefit_type) + grants = await self._get_by_customer_and_benefit_type( + session, customer, benefit_type + ) for grant in grants: if not grant.is_granted and not grant.is_revoked: enqueue_job( "benefit.grant", - user_id=user.id, + customer_id=customer.id, benefit_id=grant.benefit_id, **grant.scope, ) @@ -492,12 +481,12 @@ async def enqueue_grants_after_precondition_fulfilled( async def get_by_benefit_and_scope( self, session: AsyncSession, - user: User, + customer: Customer, benefit: Benefit, **scope: Unpack[BenefitGrantScope], ) -> BenefitGrant | None: statement = select(BenefitGrant).where( - BenefitGrant.user_id == user.id, + BenefitGrant.customer_id == customer.id, BenefitGrant.benefit_id == benefit.id, BenefitGrant.deleted_at.is_(None), BenefitGrant.scope == scope, @@ -518,15 +507,15 @@ async def _get_granted_by_benefit( result = await session.execute(statement) return result.scalars().all() - async def _get_granted_by_benefit_and_user( + async def _get_granted_by_benefit_and_customer( self, session: AsyncSession, benefit: Benefit, - user: User, + customer: Customer, ) -> Sequence[BenefitGrant]: statement = select(BenefitGrant).where( BenefitGrant.benefit_id == benefit.id, - BenefitGrant.user_id == user.id, + BenefitGrant.customer_id == customer.id, BenefitGrant.is_granted.is_(True), BenefitGrant.deleted_at.is_(None), ) @@ -534,17 +523,17 @@ async def _get_granted_by_benefit_and_user( result = await session.execute(statement) return result.scalars().all() - async def _get_by_user_and_benefit_type( + async def _get_by_customer_and_benefit_type( self, session: AsyncSession, - user: User, + customer: Customer, benefit_type: BenefitType, ) -> Sequence[BenefitGrant]: statement = ( select(BenefitGrant) .join(Benefit) .where( - BenefitGrant.user_id == user.id, + BenefitGrant.customer_id == customer.id, Benefit.type == benefit_type, ) ) diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index d94da6019c..ab92f740b0 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -10,7 +10,6 @@ from polar.auth.models import ( Anonymous, AuthSubject, - is_direct_user, is_organization, is_user, ) @@ -22,7 +21,6 @@ CheckoutUpdate, CheckoutUpdatePublic, ) -from polar.checkout.tax import TaxID, to_stripe_tax_id, validate_tax_id from polar.config import settings from polar.custom_field.data import validate_custom_field_data from polar.discount.service import DiscountNotRedeemableError @@ -37,12 +35,14 @@ from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.kit.sorting import Sorting +from polar.kit.tax import TaxID, to_stripe_tax_id, validate_tax_id from polar.kit.utils import utc_now from polar.locker import Locker from polar.logging import Logger from polar.models import ( Checkout, CheckoutLink, + Customer, Discount, Organization, Product, @@ -63,14 +63,13 @@ from polar.postgres import AsyncSession from polar.product.service.product import product as product_service from polar.product.service.product_price import product_price as product_price_service -from polar.user.service.user import user as user_service from polar.webhook.service import webhook as webhook_service from polar.worker import enqueue_job +from ..kit.tax import TaxCalculationError, calculate_tax from . import ip_geolocation from .eventstream import CheckoutEvent, publish_checkout_event from .sorting import CheckoutSortProperty -from .tax import TaxCalculationError, calculate_tax log: Logger = structlog.get_logger() @@ -303,7 +302,7 @@ async def create( ) from e subscription: Subscription | None = None - customer: User | None = None + customer: Customer | None = None if checkout_create.subscription_id is not None: subscription, customer = await self._get_validated_subscription( session, checkout_create.subscription_id, product.organization_id @@ -458,18 +457,7 @@ async def client_create( customer=None, subscription=None, ) - if is_direct_user(auth_subject): - checkout.customer = auth_subject.subject - checkout.customer_email = auth_subject.subject.email - if checkout_create.subscription_id is not None: - subscription, _ = await self._get_validated_subscription( - session, - checkout_create.subscription_id, - product.organization_id, - auth_subject.subject.id, - ) - checkout.subscription = subscription - elif checkout_create.customer_email is not None: + if checkout_create.customer_email is not None: checkout.customer_email = checkout_create.customer_email if checkout.payment_processor == PaymentProcessor.stripe: @@ -723,9 +711,10 @@ async def _confirm_inner( assert checkout.customer_email is not None if checkout.payment_processor == PaymentProcessor.stripe: - stripe_customer_id = await self._create_or_update_stripe_customer( - session, checkout - ) + customer = await self._create_or_update_customer(session, checkout) + checkout.customer = customer + stripe_customer_id = customer.stripe_customer_id + assert stripe_customer_id is not None checkout.payment_processor_metadata = {"customer_id": stripe_customer_id} if checkout.is_payment_required or checkout.is_payment_setup_required: @@ -1192,8 +1181,7 @@ async def _get_validated_subscription( session: AsyncSession, subscription_id: uuid.UUID, organization_id: uuid.UUID, - user_id: uuid.UUID | None = None, - ) -> tuple[Subscription, User]: + ) -> tuple[Subscription, Customer]: statement = ( select(Subscription) .join(Product) @@ -1204,11 +1192,9 @@ async def _get_validated_subscription( .options( contains_eager(Subscription.product), joinedload(Subscription.price), - joinedload(Subscription.user), + joinedload(Subscription.customer), ) ) - if user_id is not None: - statement = statement.where(Subscription.user_id == user_id) result = await session.execute(statement) subscription = result.scalars().unique().one_or_none() @@ -1236,7 +1222,7 @@ async def _get_validated_subscription( ] ) - return subscription, subscription.user + return subscription, subscription.customer async def _update_checkout( self, @@ -1525,21 +1511,26 @@ def _get_required_confirm_fields(self, checkout: Checkout) -> set[str]: fields.update({"customer_name", "customer_billing_address"}) return fields - async def _create_or_update_stripe_customer( + async def _create_or_update_customer( self, session: AsyncSession, checkout: Checkout - ) -> str: - assert checkout.customer_email is not None + ) -> Customer: + customer = checkout.customer + if customer is None: + assert checkout.customer_email is not None + customer = Customer( + email=checkout.customer_email, + email_verified=False, + stripe_customer_id=None, + name=checkout.customer_name, + billing_address=checkout.customer_billing_address, + tax_id=checkout.customer_tax_id, + organization=checkout.organization, + ) - stripe_customer_id: str | None = None - if checkout.customer_id is not None: - user = await user_service.get(session, checkout.customer_id) - if user is not None and user.stripe_customer_id is not None: - stripe_customer_id = user.stripe_customer_id + stripe_customer_id = customer.stripe_customer_id if stripe_customer_id is None: - create_params: stripe_lib.Customer.CreateParams = { - "email": checkout.customer_email - } + create_params: stripe_lib.Customer.CreateParams = {"email": customer.email} if checkout.customer_name is not None: create_params["name"] = checkout.customer_name if checkout.customer_billing_address is not None: @@ -1551,9 +1542,7 @@ async def _create_or_update_stripe_customer( stripe_customer = await stripe_service.create_customer(**create_params) stripe_customer_id = stripe_customer.id else: - update_params: stripe_lib.Customer.ModifyParams = { - "email": checkout.customer_email - } + update_params: stripe_lib.Customer.ModifyParams = {"email": customer.email} if checkout.customer_name is not None: update_params["name"] = checkout.customer_name if checkout.customer_billing_address is not None: @@ -1566,7 +1555,10 @@ async def _create_or_update_stripe_customer( **update_params, ) - return stripe_customer_id + session.add(customer) + await session.flush() + + return customer async def _get_eager_loaded_checkout( self, session: AsyncSession, checkout_id: uuid.UUID diff --git a/server/polar/customer/__init__.py b/server/polar/customer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py new file mode 100644 index 0000000000..4c2ba415cd --- /dev/null +++ b/server/polar/customer/service.py @@ -0,0 +1,51 @@ +from sqlalchemy import select +from stripe import Customer as StripeCustomer + +from polar.kit.services import ResourceServiceReader +from polar.models import Customer, Organization +from polar.postgres import AsyncSession + + +class CustomerService(ResourceServiceReader[Customer]): + async def get_by_id_and_organization( + self, session: AsyncSession, id: str, organization: Organization + ) -> Customer | None: + statement = select(Customer).where( + Customer.deleted_at.is_(None), + Customer.id == id, + Customer.organization_id == organization.id, + ) + result = await session.execute(statement) + return result.scalar_one_or_none() + + async def get_by_stripe_customer_id( + self, session: AsyncSession, stripe_customer_id: str + ) -> Customer | None: + statement = select(Customer).where( + Customer.deleted_at.is_(None), + Customer.stripe_customer_id == stripe_customer_id, + ) + result = await session.execute(statement) + return result.scalar_one_or_none() + + async def create_from_stripe_customer( + self, + session: AsyncSession, + stripe_customer: StripeCustomer, + organization: Organization, + ) -> Customer: + customer = Customer( + email=stripe_customer.email, + email_verified=False, + stripe_customer_id=stripe_customer.id, + name=stripe_customer.name, + billing_address=stripe_customer.address, + # TODO: tax_id, + organization=organization, + ) + + session.add(customer) + return customer + + +customer = CustomerService(Customer) diff --git a/server/polar/eventstream/service.py b/server/polar/eventstream/service.py index aa5d8f017e..5059dae594 100644 --- a/server/polar/eventstream/service.py +++ b/server/polar/eventstream/service.py @@ -20,6 +20,7 @@ class Receivers(BaseModel): user_id: UUID | None = None organization_id: UUID | None = None checkout_client_secret: str | None = None + customer_id: UUID | None = None def generate_channel_name(self, scope: str, resource_id: UUID | str) -> str: return f"{scope}:{resource_id}" @@ -37,6 +38,9 @@ def get_channels(self) -> list[str]: self.generate_channel_name("checkout", self.checkout_client_secret) ) + if self.customer_id: + channels.append(self.generate_channel_name("customer", self.customer_id)) + return channels @@ -60,6 +64,7 @@ async def publish( user_id: UUID | None = None, organization_id: UUID | None = None, checkout_client_secret: str | None = None, + customer_id: UUID | None = None, *, run_in_worker: bool = True, redis: Redis | None = None, @@ -68,6 +73,7 @@ async def publish( user_id=user_id, organization_id=organization_id, checkout_client_secret=checkout_client_secret, + customer_id=customer_id, ) channels = receivers.get_channels() event = Event( diff --git a/server/polar/checkout/tax.py b/server/polar/kit/tax.py similarity index 100% rename from server/polar/checkout/tax.py rename to server/polar/kit/tax.py diff --git a/server/polar/license_key/schemas.py b/server/polar/license_key/schemas.py index 34abf22e71..db70e4fcee 100644 --- a/server/polar/license_key/schemas.py +++ b/server/polar/license_key/schemas.py @@ -113,7 +113,7 @@ class LicenseKeyUpdate(Schema): class LicenseKeyCreate(LicenseKeyUpdate): organization_id: UUID4 - user_id: UUID4 + customer_id: UUID4 benefit_id: BenefitID key: str @@ -143,7 +143,7 @@ def generate_expiration_dt( def build( cls, organization_id: UUID4, - user_id: UUID4, + customer_id: UUID4, benefit_id: UUID4, prefix: str | None = None, status: LicenseKeyStatus = LicenseKeyStatus.granted, @@ -165,7 +165,7 @@ def build( key = cls.generate_key(prefix=prefix) return cls( organization_id=organization_id, - user_id=user_id, + customer_id=customer_id, benefit_id=benefit_id, key=key, status=status, diff --git a/server/polar/license_key/service.py b/server/polar/license_key/service.py index bf761cff42..5983939199 100644 --- a/server/polar/license_key/service.py +++ b/server/polar/license_key/service.py @@ -12,6 +12,7 @@ from polar.kit.utils import utc_now from polar.models import ( Benefit, + Customer, LicenseKey, LicenseKeyActivation, Organization, @@ -104,13 +105,13 @@ async def get_by_grant_or_raise( *, id: UUID, organization_id: UUID, - user_id: UUID, + customer_id: UUID, benefit_id: UUID, ) -> LicenseKey: query = self._get_select_base().where( LicenseKey.id == id, LicenseKey.organization_id == organization_id, - LicenseKey.user_id == user_id, + LicenseKey.customer_id == customer_id, LicenseKey.benefit_id == benefit_id, ) result = await session.execute(query) @@ -328,7 +329,7 @@ async def activate( "license_key.activate.limit_reached", license_key_id=license_key.id, organization_id=license_key.organization_id, - user=license_key.user_id, + customer_id=license_key.customer_id, benefit_id=license_key.benefit_id, ) raise NotPermitted("License key activation limit already reached") @@ -346,7 +347,7 @@ async def activate( "license_key.activate", license_key_id=license_key.id, organization_id=license_key.organization_id, - user=license_key.user_id, + customer_id=license_key.customer_id, benefit_id=license_key.benefit_id, activation_id=instance.id, ) @@ -371,7 +372,7 @@ async def deactivate( "license_key.deactivate", license_key_id=license_key.id, organization_id=license_key.organization_id, - user=license_key.user_id, + customer_id=license_key.customer_id, benefit_id=license_key.benefit_id, activation_id=activation.id, ) @@ -381,14 +382,14 @@ async def user_grant( self, session: AsyncSession, *, - user: User, + customer: Customer, benefit: BenefitLicenseKeys, license_key_id: UUID | None = None, ) -> LicenseKey: props = benefit.properties create_schema = LicenseKeyCreate.build( organization_id=benefit.organization_id, - user_id=user.id, + customer_id=customer.id, benefit_id=benefit.id, prefix=props.get("prefix", None), limit_usage=props.get("limit_usage", None), @@ -398,7 +399,7 @@ async def user_grant( log.info( "license_key.grant.request", organization_id=benefit.organization_id, - user=user.id, + customer_id=customer.id, benefit_id=benefit.id, ) if license_key_id: @@ -424,7 +425,7 @@ async def user_update_grant( session, id=license_key_id, organization_id=create_schema.organization_id, - user_id=create_schema.user_id, + customer_id=create_schema.customer_id, benefit_id=create_schema.benefit_id, ) @@ -447,7 +448,7 @@ async def user_update_grant( "license_key.grant.update", license_key_id=key.id, organization_id=key.organization_id, - user=key.user_id, + customer_id=key.customer_id, benefit_id=key.benefit_id, ) return key @@ -466,7 +467,7 @@ async def user_create_grant( "license_key.grant.create", license_key_id=key.id, organization_id=key.organization_id, - user=key.user_id, + customer_id=key.customer_id, benefit_id=key.benefit_id, ) return key @@ -474,7 +475,7 @@ async def user_create_grant( async def user_revoke( self, session: AsyncSession, - user: User, + customer: Customer, benefit: BenefitLicenseKeys, license_key_id: UUID, ) -> LicenseKey: @@ -482,7 +483,7 @@ async def user_revoke( session, id=license_key_id, organization_id=benefit.organization_id, - user_id=user.id, + customer_id=customer.id, benefit_id=benefit.id, ) key.mark_revoked() @@ -492,7 +493,7 @@ async def user_revoke( "license_key.revoke", license_key_id=key.id, organization_id=key.organization_id, - user=key.user_id, + customer_id=key.customer_id, benefit_id=key.benefit_id, ) return key @@ -500,9 +501,7 @@ async def user_revoke( def _get_select_base(self) -> Select[tuple[LicenseKey]]: return ( select(LicenseKey) - .options( - joinedload(LicenseKey.user), - ) + .options(joinedload(LicenseKey.customer)) .where(LicenseKey.deleted_at.is_(None)) ) diff --git a/server/polar/models/__init__.py b/server/polar/models/__init__.py index 77e5428b96..d42f0cbac9 100644 --- a/server/polar/models/__init__.py +++ b/server/polar/models/__init__.py @@ -7,6 +7,7 @@ from .checkout import Checkout from .checkout_link import CheckoutLink from .custom_field import CustomField +from .customer import Customer from .discount import Discount from .discount_product import DiscountProduct from .discount_redemption import DiscountRedemption @@ -59,6 +60,7 @@ "BenefitGrant", "Checkout", "CheckoutLink", + "Customer", "CustomField", "Discount", "DiscountProduct", diff --git a/server/polar/models/benefit_grant.py b/server/polar/models/benefit_grant.py index 952585ab68..e1e1ca293d 100644 --- a/server/polar/models/benefit_grant.py +++ b/server/polar/models/benefit_grant.py @@ -26,7 +26,7 @@ from polar.kit.db.models import RecordModel if TYPE_CHECKING: - from polar.models import Benefit, Order, Subscription, User + from polar.models import Benefit, Customer, Order, Subscription class BenefitGrantScope(TypedDict, total=False): @@ -119,7 +119,10 @@ class BenefitGrant(RecordModel): __tablename__ = "benefit_grants" __table_args__ = ( UniqueConstraint( - "subscription_id", "user_id", "benefit_id", name="benefit_grants_sbu_key" + "subscription_id", + "customer_id", + "benefit_id", + name="benefit_grants_sbc_key", ), ) @@ -133,16 +136,16 @@ class BenefitGrant(RecordModel): "properties", JSONB, nullable=False, default=dict ) - user_id: Mapped[UUID] = mapped_column( + customer_id: Mapped[UUID] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="cascade"), + ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True, ) @declared_attr - def user(cls) -> Mapped["User"]: - return relationship("User", lazy="raise") + def customer(cls) -> Mapped["Customer"]: + return relationship("Customer", lazy="raise") benefit_id: Mapped[UUID] = mapped_column( Uuid, diff --git a/server/polar/models/checkout.py b/server/polar/models/checkout.py index a539776cac..4ef3c42621 100644 --- a/server/polar/models/checkout.py +++ b/server/polar/models/checkout.py @@ -19,7 +19,6 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, Mapper, declared_attr, mapped_column, relationship -from polar.checkout.tax import TaxID, TaxIDType from polar.config import settings from polar.custom_field.attachment import AttachedCustomFieldMixin from polar.custom_field.data import CustomFieldDataMixin @@ -27,14 +26,15 @@ from polar.kit.address import Address, AddressType from polar.kit.db.models import RecordModel from polar.kit.metadata import MetadataMixin +from polar.kit.tax import TaxID, TaxIDType from polar.kit.utils import utc_now +from .customer import Customer from .discount import Discount from .organization import Organization from .product import Product from .product_price import ProductPrice, ProductPriceFixed, ProductPriceFree from .subscription import Subscription -from .user import User def get_expires_at() -> datetime: @@ -112,13 +112,13 @@ def discount(cls) -> Mapped[Discount | None]: customer_id: Mapped[UUID | None] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="cascade"), + ForeignKey("customers.id", ondelete="set null"), nullable=True, ) @declared_attr - def customer(cls) -> Mapped[User | None]: - return relationship(User, lazy="raise") + def customer(cls) -> Mapped[Customer | None]: + return relationship(Customer, lazy="raise") customer_name: Mapped[str | None] = mapped_column( String, nullable=True, default=None diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py new file mode 100644 index 0000000000..b5d1adc380 --- /dev/null +++ b/server/polar/models/customer.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING +from uuid import UUID + +from sqlalchemy import Boolean, ForeignKey, String, Uuid +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship + +from polar.kit.address import Address, AddressType +from polar.kit.db.models import RecordModel +from polar.kit.metadata import MetadataMixin +from polar.kit.tax import TaxID, TaxIDType + +if TYPE_CHECKING: + from .organization import Organization + + +class Customer(MetadataMixin, RecordModel): + __tablename__ = "customers" + + email: Mapped[str] = mapped_column(String(320), nullable=False) + email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + stripe_customer_id: Mapped[str | None] = mapped_column( + String, nullable=True, default=None, unique=True + ) + + name: Mapped[str | None] = mapped_column(String, nullable=True, default=None) + billing_address: Mapped[Address | None] = mapped_column( + AddressType, nullable=True, default=None + ) + tax_id: Mapped[TaxID | None] = mapped_column(TaxIDType, nullable=True, default=None) + + organization_id: Mapped[UUID] = mapped_column( + Uuid, + ForeignKey("organizations.id", ondelete="cascade"), + nullable=False, + ) + + @declared_attr + def organization(cls) -> Mapped["Organization"]: + return relationship("Organization", lazy="raise") diff --git a/server/polar/models/downloadable.py b/server/polar/models/downloadable.py index b623217f88..ce04771ffc 100644 --- a/server/polar/models/downloadable.py +++ b/server/polar/models/downloadable.py @@ -15,8 +15,8 @@ from polar.kit.db.models import RecordModel from .benefit import Benefit +from .customer import Customer from .file import File -from .user import User class DownloadableStatus(StrEnum): @@ -26,7 +26,7 @@ class DownloadableStatus(StrEnum): class Downloadable(RecordModel): __tablename__ = "downloadables" - __table_args__ = (UniqueConstraint("user_id", "file_id", "benefit_id"),) + __table_args__ = (UniqueConstraint("customer_id", "file_id", "benefit_id"),) file_id: Mapped[UUID] = mapped_column( Uuid, ForeignKey("files.id"), nullable=False, index=True @@ -38,16 +38,16 @@ def file(cls) -> Mapped[File]: status: Mapped[DownloadableStatus] = mapped_column(String, nullable=False) - user_id: Mapped[UUID] = mapped_column( + customer_id: Mapped[UUID] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="cascade"), + ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True, ) @declared_attr - def user(cls) -> Mapped[User]: - return relationship("User", lazy="raise") + def customer(cls) -> Mapped[Customer]: + return relationship("Customer", lazy="raise") benefit_id: Mapped[UUID] = mapped_column( Uuid, diff --git a/server/polar/models/license_key.py b/server/polar/models/license_key.py index ae07899aa1..9a65b57ee2 100644 --- a/server/polar/models/license_key.py +++ b/server/polar/models/license_key.py @@ -16,7 +16,7 @@ from polar.kit.utils import utc_now from .benefit import BenefitLicenseKeys -from .user import User +from .customer import Customer if TYPE_CHECKING: from .license_key_activation import LicenseKeyActivation @@ -43,16 +43,16 @@ class LicenseKey(RecordModel): def organization(cls) -> Mapped["Organization"]: return relationship("Organization", lazy="raise") - user_id: Mapped[UUID] = mapped_column( + customer_id: Mapped[UUID] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="cascade"), + ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True, ) @declared_attr - def user(cls) -> Mapped[User]: - return relationship("User", lazy="raise") + def customer(cls) -> Mapped[Customer]: + return relationship("Customer", lazy="raise") benefit_id: Mapped[UUID] = mapped_column( Uuid, diff --git a/server/polar/models/order.py b/server/polar/models/order.py index 244399ff2c..3fd9951dab 100644 --- a/server/polar/models/order.py +++ b/server/polar/models/order.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: from polar.models import ( Checkout, + Customer, Discount, Organization, Product, ProductPrice, Subscription, - User, ) @@ -44,15 +44,15 @@ class Order(CustomFieldDataMixin, MetadataMixin, RecordModel): String, nullable=True, unique=True ) - user_id: Mapped[UUID] = mapped_column( + customer_id: Mapped[UUID] = mapped_column( Uuid, - ForeignKey("users.id"), + ForeignKey("customers.id"), nullable=False, ) @declared_attr - def user(cls) -> Mapped["User"]: - return relationship("User", lazy="raise") + def customer(cls) -> Mapped["Customer"]: + return relationship("Customer", lazy="raise") product_id: Mapped[UUID] = mapped_column( Uuid, diff --git a/server/polar/models/subscription.py b/server/polar/models/subscription.py index c0e1be9693..fc39965224 100644 --- a/server/polar/models/subscription.py +++ b/server/polar/models/subscription.py @@ -26,11 +26,11 @@ from polar.models import ( BenefitGrant, Checkout, + Customer, Discount, Organization, Product, ProductPrice, - User, ) @@ -95,16 +95,16 @@ class Subscription(CustomFieldDataMixin, MetadataMixin, RecordModel): TIMESTAMP(timezone=True), nullable=True, default=None ) - user_id: Mapped[UUID] = mapped_column( + customer_id: Mapped[UUID] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="cascade"), + ForeignKey("customers.id", ondelete="cascade"), nullable=False, index=True, ) @declared_attr - def user(cls) -> Mapped["User"]: - return relationship("User", lazy="raise") + def customer(cls) -> Mapped["Customer"]: + return relationship("Customer", lazy="raise") product_id: Mapped[UUID] = mapped_column( Uuid, diff --git a/server/polar/models/transaction.py b/server/polar/models/transaction.py index 426feb4785..f289d2faec 100644 --- a/server/polar/models/transaction.py +++ b/server/polar/models/transaction.py @@ -10,11 +10,11 @@ if TYPE_CHECKING: from polar.models import ( Account, + Customer, IssueReward, Order, Organization, Pledge, - User, ) @@ -263,17 +263,17 @@ class Transaction(RecordModel): def account(cls) -> Mapped["Account | None"]: return relationship("Account", lazy="raise") - payment_user_id: Mapped[UUID | None] = mapped_column( + payment_customer_id: Mapped[UUID | None] = mapped_column( Uuid, - ForeignKey("users.id", ondelete="set null"), + ForeignKey("customers.id", ondelete="set null"), nullable=True, index=True, ) - """ID of the `User` who made the payment.""" + """ID of the `Customer` who made the payment.""" @declared_attr - def payment_user(cls) -> Mapped["User | None"]: - return relationship("User", lazy="raise") + def payment_customer(cls) -> Mapped["Customer | None"]: + return relationship("Customer", lazy="raise") payment_organization_id: Mapped[UUID | None] = mapped_column( Uuid, diff --git a/server/polar/order/service.py b/server/polar/order/service.py index a6a393bcc8..ff1ffdb34b 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -13,6 +13,7 @@ from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event from polar.checkout.service import checkout as checkout_service from polar.config import settings +from polar.customer.service import customer as customer_service from polar.discount.service import discount as discount_service from polar.email.renderer import get_email_renderer from polar.email.sender import get_email_sender @@ -29,6 +30,7 @@ from polar.logging import Logger from polar.models import ( Checkout, + Customer, Discount, HeldBalance, Order, @@ -62,8 +64,6 @@ from polar.transaction.service.platform_fee import ( platform_fee_transaction as platform_fee_transaction_service, ) -from polar.user.schemas.user import UserSignupAttribution -from polar.user.service.user import user as user_service from polar.webhook.service import webhook as webhook_service from polar.webhook.webhooks import WebhookTypeObject from polar.worker import enqueue_job @@ -177,7 +177,7 @@ async def list( product_id: Sequence[uuid.UUID] | None = None, product_price_type: Sequence[ProductPriceType] | None = None, discount_id: Sequence[uuid.UUID] | None = None, - user_id: Sequence[uuid.UUID] | None = None, + customer_id: Sequence[uuid.UUID] | None = None, pagination: PaginationParams, sorting: list[Sorting[OrderSortProperty]] = [ (OrderSortProperty.created_at, True) @@ -195,10 +195,10 @@ async def list( OrderProductPrice, onclause=Order.product_price_id == OrderProductPrice.id ).options(contains_eager(Order.product_price.of_type(OrderProductPrice))) - OrderUser = aliased(User) + OrderCustomer = aliased(Customer) statement = statement.join( - OrderUser, onclause=Order.user_id == OrderUser.id - ).options(contains_eager(Order.user.of_type(OrderUser))) + OrderCustomer, onclause=Order.customer_id == OrderCustomer.id + ).options(contains_eager(Order.customer.of_type(OrderCustomer))) if organization_id is not None: statement = statement.where(Product.organization_id.in_(organization_id)) @@ -212,8 +212,8 @@ async def list( if discount_id is not None: statement = statement.where(Order.discount_id.in_(discount_id)) - if user_id is not None: - statement = statement.where(Order.user_id.in_(user_id)) + if customer_id is not None: + statement = statement.where(Order.customer_id.in_(customer_id)) order_by_clauses: list[UnaryExpression[Any]] = [] for criterion, is_desc in sorting: @@ -222,8 +222,8 @@ async def list( order_by_clauses.append(clause_function(Order.created_at)) elif criterion == OrderSortProperty.amount: order_by_clauses.append(clause_function(Order.amount)) - elif criterion == OrderSortProperty.user: - order_by_clauses.append(clause_function(OrderUser.email)) + elif criterion == OrderSortProperty.customer: + order_by_clauses.append(clause_function(OrderCustomer.email)) elif criterion == OrderSortProperty.product: order_by_clauses.append(clause_function(Product.name)) elif criterion == OrderSortProperty.discount: @@ -244,7 +244,7 @@ async def get_by_id( self._get_readable_order_statement(auth_subject) .where(Order.id == id) .options( - joinedload(Order.user), + joinedload(Order.customer), joinedload(Order.product_price), joinedload(Order.subscription), joinedload(Order.discount), @@ -354,7 +354,7 @@ async def create_order_from_stripe( if checkout is None: raise CheckoutDoesNotExist(invoice.id, checkout_id) - user: User | None = None + customer: Customer | None = None billing_reason: OrderBillingReason = OrderBillingReason.purchase tax = invoice.tax or 0 @@ -369,7 +369,7 @@ async def create_order_from_stripe( ) if subscription is None: raise SubscriptionDoesNotExist(invoice.id, stripe_subscription_id) - user = await user_service.get(session, subscription.user_id) + customer = await customer_service.get(session, subscription.customer_id) if invoice.billing_reason is not None: try: billing_reason = OrderBillingReason(invoice.billing_reason) @@ -389,8 +389,6 @@ async def create_order_from_stripe( # Create Order order = Order( - # Generate ID upfront for user attribution - id=Order.generate_id(), amount=amount, tax_amount=tax, currency=invoice.currency, @@ -409,36 +407,20 @@ async def create_order_from_stripe( created_at=datetime.fromtimestamp(invoice.created, tz=UTC), ) - # Get or create customer user + # Get or create customer assert invoice.customer is not None stripe_customer_id = get_expandable_id(invoice.customer) - if user is None: - user = await user_service.get_by_stripe_customer_id( + if customer is None: + customer = await customer_service.get_by_stripe_customer_id( session, stripe_customer_id ) - if user is None: - assert invoice.customer_email is not None - signup_attribution = UserSignupAttribution( - intent="purchase", - order=order.id, - ) - if order.subscription: - signup_attribution = UserSignupAttribution( - intent="subscription", - subscription=order.subscription.id, - ) - - user, _ = await user_service.get_by_email_or_create( - session, - invoice.customer_email, - signup_attribution=signup_attribution, + stripe_customer = await stripe_service.get_customer(stripe_customer_id) + if customer is None: + customer = await customer_service.create_from_stripe_customer( + session, stripe_customer, product.organization ) - # Take the chance to update Stripe customer ID and email marketing - user.stripe_customer_id = stripe_customer_id - session.add(user) - - order.user = user + order.customer = customer session.add(order) await session.flush() @@ -472,7 +454,7 @@ async def create_order_from_stripe( enqueue_job( "benefit.enqueue_benefits_grants", task="grant", - user_id=user.id, + customer_id=customer.id, product_id=product.id, order_id=order.id, ) @@ -510,7 +492,7 @@ async def send_admin_notification( notif=PartialNotification( type=NotificationType.maintainer_new_product_sale, payload=MaintainerNewProductSaleNotificationPayload( - customer_name=order.user.email, + customer_name=order.customer.email, product_name=product.name, product_price_amount=order.amount, organization_name=organization.slug, @@ -525,7 +507,7 @@ async def send_confirmation_email( email_sender = get_email_sender() product = order.product - user = order.user + customer = order.customer subject, body = email_renderer.render_from_template( "Your {{ product.name }} order confirmation", "order/confirmation.html", @@ -538,7 +520,7 @@ async def send_confirmation_email( ) email_sender.send_to_user( - to_email_addr=user.email, subject=subject, html_content=body + to_email_addr=customer.email, subject=subject, html_content=body ) async def update_product_benefits_grants( @@ -554,7 +536,7 @@ async def update_product_benefits_grants( enqueue_job( "benefit.enqueue_benefits_grants", task="grant", - user_id=order.user_id, + customer_id=order.customer_id, product_id=product.id, order_id=order.id, ) @@ -650,10 +632,6 @@ async def _send_webhook(self, session: AsyncSession, order: Order) -> None: event: WebhookTypeObject = (WebhookEventType.order_created, order) - # Webhook for customer - await webhook_service.send(session, order.user, event) - - # Webhook for organization organization = await organization_service.get( session, order.product.organization_id ) diff --git a/server/polar/order/sorting.py b/server/polar/order/sorting.py index b0a36e5304..539ff36cf9 100644 --- a/server/polar/order/sorting.py +++ b/server/polar/order/sorting.py @@ -9,7 +9,7 @@ class OrderSortProperty(StrEnum): created_at = "created_at" amount = "amount" - user = "user" + customer = "customer" product = "product" discount = "discount" subscription = "subscription" diff --git a/server/polar/storefront/schemas.py b/server/polar/storefront/schemas.py index be2f8e526b..fed6cac743 100644 --- a/server/polar/storefront/schemas.py +++ b/server/polar/storefront/schemas.py @@ -21,10 +21,7 @@ class ProductStorefront(ProductBase): ) -class Customer(Schema): - public_name: str - github_username: str | None - avatar_url: str | None +class Customer(Schema): ... class Customers(Schema): diff --git a/server/polar/storefront/service.py b/server/polar/storefront/service.py index 2763978daf..0df24a6465 100644 --- a/server/polar/storefront/service.py +++ b/server/polar/storefront/service.py @@ -1,11 +1,10 @@ from collections.abc import Sequence -from sqlalchemy import and_, select +from sqlalchemy import select from sqlalchemy.orm import selectinload from polar.kit.pagination import PaginationParams, paginate -from polar.models import OAuthAccount, Order, Organization, Product, User -from polar.models.user import OAuthPlatform +from polar.models import Customer, Order, Organization, Product from polar.postgres import AsyncSession @@ -34,31 +33,16 @@ async def list_customers( organization: Organization, *, pagination: PaginationParams, - ) -> tuple[Sequence[User], int]: - statement = ( - select(User) - .join( - OAuthAccount, - onclause=and_( - User.id == OAuthAccount.user_id, - OAuthAccount.platform == OAuthPlatform.github, - ), - isouter=True, - ) - .where( - User.id.in_( - select(Order.user_id) - .join(Product, Product.id == Order.product_id) - .where( - Order.deleted_at.is_(None), - Product.organization_id == organization.id, - ) + ) -> tuple[Sequence[Customer], int]: + statement = select(Customer).where( + Customer.id.in_( + select(Order.customer_id) + .join(Product, Product.id == Order.product_id) + .where( + Order.deleted_at.is_(None), + Product.organization_id == organization.id, ) ) - .order_by( - # Put users with a GitHub account first, so we can display their avatar - OAuthAccount.created_at.desc().nulls_last() - ) ) results, count = await paginate(session, statement, pagination=pagination) return results, count diff --git a/server/polar/subscription/service.py b/server/polar/subscription/service.py index 2c479cd41c..fcb5214565 100644 --- a/server/polar/subscription/service.py +++ b/server/polar/subscription/service.py @@ -17,6 +17,7 @@ from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event from polar.checkout.service import checkout as checkout_service from polar.config import settings +from polar.customer.service import customer as customer_service from polar.discount.service import discount as discount_service from polar.email.renderer import get_email_renderer from polar.email.sender import get_email_sender @@ -33,6 +34,7 @@ Benefit, BenefitGrant, Checkout, + Customer, Discount, Organization, Product, @@ -54,9 +56,6 @@ from polar.notifications.service import notifications as notifications_service from polar.organization.service import organization as organization_service from polar.postgres import sql -from polar.posthog import posthog -from polar.user.schemas.user import UserSignupAttribution -from polar.user.service.user import user as user_service from polar.webhook.service import webhook as webhook_service from polar.webhook.webhooks import WebhookTypeObject from polar.worker import enqueue_job @@ -142,7 +141,7 @@ def _from_timestamp(t: int | None) -> datetime | None: class SubscriptionSortProperty(StrEnum): - user = "user" + customer = "customer" status = "status" started_at = "started_at" current_period_end = "current_period_end" @@ -169,7 +168,7 @@ async def get( query = query.options(*options) else: query = query.options( - joinedload(Subscription.user), + joinedload(Subscription.customer), joinedload(Subscription.price), joinedload(Subscription.product).options( selectinload(Product.product_medias), @@ -200,7 +199,7 @@ async def list( ) statement = ( - statement.join(Subscription.user) + statement.join(Subscription.customer) .join(Subscription.price, isouter=True) .join(Subscription.discount, isouter=True) ) @@ -223,8 +222,8 @@ async def list( order_by_clauses: list[UnaryExpression[Any]] = [] for criterion, is_desc in sorting: clause_function = desc if is_desc else asc - if criterion == SubscriptionSortProperty.user: - order_by_clauses.append(clause_function(User.email)) + if criterion == SubscriptionSortProperty.customer: + order_by_clauses.append(clause_function(Customer.email)) if criterion == SubscriptionSortProperty.status: order_by_clauses.append( clause_function( @@ -285,7 +284,7 @@ async def list( ), contains_eager(Subscription.price), contains_eager(Subscription.discount), - contains_eager(Subscription.user), + contains_eager(Subscription.customer), ) results, count = await paginate(session, statement, pagination=pagination) @@ -302,7 +301,7 @@ async def create_arbitrary_subscription( self, session: AsyncSession, *, - user: User, + customer: Customer, product: Product, price: ProductPriceFixed, ) -> Subscription: ... @@ -312,7 +311,7 @@ async def create_arbitrary_subscription( self, session: AsyncSession, *, - user: User, + customer: Customer, product: Product, price: ProductPriceCustom, amount: int, @@ -323,7 +322,7 @@ async def create_arbitrary_subscription( self, session: AsyncSession, *, - user: User, + customer: Customer, product: Product, price: ProductPriceFree, ) -> Subscription: ... @@ -332,7 +331,7 @@ async def create_arbitrary_subscription( self, session: AsyncSession, *, - user: User, + customer: Customer, product: Product, price: ProductPriceFixed | ProductPriceCustom | ProductPriceFree, amount: int | None = None, @@ -357,7 +356,7 @@ async def create_arbitrary_subscription( current_period_start=start, cancel_at_period_end=False, started_at=start, - user=user, + customer=customer, product=product, price=price, ) @@ -415,18 +414,14 @@ async def create_subscription_from_stripe( statement = ( select(Subscription) .where(Subscription.id == uuid.UUID(existing_subscription_id)) - .options(joinedload(Subscription.user)) + .options(joinedload(Subscription.customer)) ) result = await session.execute(statement) subscription = result.unique().scalar_one_or_none() # New subscription if subscription is None: - subscription = Subscription( - # Generate ID upfront for user attribution - id=Subscription.generate_id(), - user=None, - ) + subscription = Subscription() subscription.stripe_subscription_id = stripe_subscription.id subscription.status = SubscriptionStatus(stripe_subscription.status) @@ -462,43 +457,22 @@ async def create_subscription_from_stripe( subscription.set_started_at() - customer_id = get_expandable_id(stripe_subscription.customer) - customer = await stripe_service.get_customer(customer_id) - customer_email = cast(str, customer.email) - - # Take user from existing subscription, or get it from metadata - user_id = stripe_subscription.metadata.get("user_id") - user = cast(User | None, subscription.user) - if user is None: - if user_id is not None: - user = await user_service.get(session, uuid.UUID(user_id)) - if user is None: - user, _ = await user_service.get_by_email_or_create( - session, - customer_email, - signup_attribution=UserSignupAttribution( - intent="subscription", - subscription=subscription.id, - ), + # Take customer from existing subscription, or retrieve it from Stripe Customer ID + if subscription.customer is None: + stripe_customer_id = get_expandable_id(stripe_subscription.customer) + customer = await customer_service.get_by_stripe_customer_id( + session, stripe_customer_id + ) + if customer is None: + stripe_customer = await stripe_service.get_customer(stripe_customer_id) + customer = await customer_service.create_from_stripe_customer( + session, stripe_customer, subscription_tier_org ) - - subscription.user = user - - # Take the chance to update Stripe customer ID and email marketing - user.stripe_customer_id = customer_id - session.add(user) + subscription.customer = customer session.add(subscription) await session.flush() - posthog.user_event( - user, - "subscriptions", - "subscription", - "create", - {"subscription_id": subscription.id}, - ) - # Notify checkout channel that a subscription has been created from it if checkout is not None: await publish_checkout_event( @@ -585,17 +559,6 @@ async def update_subscription_from_stripe( session.add(subscription) - if subscription.cancel_at_period_end or subscription.ended_at: - user = await user_service.get(session, subscription.user_id) - if user: - posthog.user_event( - user, - "subscriptions", - "subscription", - "cancel", - {"subscription_id": subscription.id}, - ) - await self.enqueue_benefits_grants(session, subscription) await self._after_subscription_updated( @@ -654,7 +617,7 @@ async def _send_new_subscription_notification( notif=PartialNotification( type=NotificationType.maintainer_new_paid_subscription, payload=MaintainerNewPaidSubscriptionNotificationPayload( - subscriber_name=subscription.user.email, + subscriber_name=subscription.customer.email, tier_name=product.name, tier_price_amount=subscription.amount, tier_price_recurring_interval=price.recurring_interval, @@ -679,11 +642,6 @@ async def _send_webhook( event = cast(WebhookTypeObject, (event_type, full_subscription)) - # subscription events for subscribing user - if subscribing_user := await user_service.get(session, subscription.user_id): - await webhook_service.send(session, target=subscribing_user, we=event) - - # subscribed to org if tier := await product_service.get_loaded(session, subscription.product_id): if subscribed_to_org := await organization_service.get( session, tier.organization_id @@ -700,12 +658,11 @@ async def enqueue_benefits_grants( return task = "grant" if subscription.active else "revoke" - user_id = subscription.user_id enqueue_job( "benefit.enqueue_benefits_grants", task=task, - user_id=user_id, + customer_id=subscription.customer_id, product_id=product.id, subscription_id=subscription.id, ) @@ -731,7 +688,6 @@ async def send_confirmation_email( session, product.organization_id ) assert featured_organization is not None - user = subscription.user subject, body = email_renderer.render_from_template( "Your {{ product.name }} subscription", @@ -748,7 +704,9 @@ async def send_confirmation_email( ) email_sender.send_to_user( - to_email_addr=user.email, subject=subject, html_content=body + to_email_addr=subscription.customer.email, + subject=subject, + html_content=body, ) async def send_cancellation_email( @@ -762,7 +720,6 @@ async def send_cancellation_email( session, product.organization_id ) assert featured_organization is not None - user = subscription.user subject, body = email_renderer.render_from_template( "Your {{ product.name }} subscription cancellation", @@ -780,7 +737,9 @@ async def send_cancellation_email( ) email_sender.send_to_user( - to_email_addr=user.email, subject=subject, html_content=body + to_email_addr=subscription.customer.email, + subject=subject, + html_content=body, ) def _get_readable_subscriptions_statement( diff --git a/server/polar/user/schemas/downloadables.py b/server/polar/user/schemas/downloadables.py index dcfc54fc88..5e803f0837 100644 --- a/server/polar/user/schemas/downloadables.py +++ b/server/polar/user/schemas/downloadables.py @@ -22,13 +22,13 @@ class DownloadableRead(Schema): class DownloadableCreate(Schema): file_id: UUID4 - user_id: UUID4 + customer_id: UUID4 benefit_id: BenefitID status: DownloadableStatus class DownloadableUpdate(Schema): file_id: UUID4 - user_id: UUID4 + customer_id: UUID4 benefit_id: BenefitID status: DownloadableStatus diff --git a/server/polar/user/service/downloadables.py b/server/polar/user/service/downloadables.py index 72d81a7f7c..f5cc2110fd 100644 --- a/server/polar/user/service/downloadables.py +++ b/server/polar/user/service/downloadables.py @@ -17,7 +17,7 @@ from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceService from polar.kit.utils import utc_now -from polar.models import Benefit, User +from polar.models import Benefit, Customer, User from polar.models.downloadable import Downloadable, DownloadableStatus from polar.models.file import File from polar.postgres import AsyncSession, sql @@ -61,7 +61,7 @@ async def get_list( async def grant_for_benefit_file( self, session: AsyncSession, - user: User, + customer: Customer, benefit_id: UUID, file_id: UUID, ) -> Downloadable | None: @@ -70,7 +70,7 @@ async def grant_for_benefit_file( log.info( "downloadables.grant.file_not_found", file_id=file_id, - user_id=user.id, + customer_id=customer.id, benefit_id=benefit_id, granted=False, ) @@ -78,7 +78,7 @@ async def grant_for_benefit_file( create_schema = DownloadableCreate( file_id=file.id, - user_id=user.id, + customer_id=customer.id, benefit_id=benefit_id, status=DownloadableStatus.granted, ) @@ -87,7 +87,7 @@ async def grant_for_benefit_file( create_schemas=[create_schema], constraints=[ Downloadable.file_id, - Downloadable.user_id, + Downloadable.customer_id, Downloadable.benefit_id, ], mutable_keys={ @@ -102,7 +102,7 @@ async def grant_for_benefit_file( log.info( "downloadables.grant", file_id=file.id, - user_id=user.id, + customer_id=customer.id, downloadables_id=instance.id, benefit_id=benefit_id, granted=True, @@ -112,13 +112,13 @@ async def grant_for_benefit_file( async def revoke_for_benefit( self, session: AsyncSession, - user: User, + customer: Customer, benefit_id: UUID, ) -> None: statement = ( sql.update(Downloadable) .where( - Downloadable.user_id == user.id, + Downloadable.customer_id == customer.id, Downloadable.benefit_id == benefit_id, Downloadable.status == DownloadableStatus.granted, Downloadable.deleted_at.is_(None), @@ -130,7 +130,7 @@ async def revoke_for_benefit( ) log.info( "downloadables.revoked", - user_id=user.id, + customer_id=customer.id, benefit_id=benefit_id, ) await session.execute(statement) diff --git a/server/tests/checkout/test_endpoints.py b/server/tests/checkout/test_endpoints.py index 739a297bcf..d65dde8fc4 100644 --- a/server/tests/checkout/test_endpoints.py +++ b/server/tests/checkout/test_endpoints.py @@ -9,8 +9,8 @@ from polar.auth.scope import Scope from polar.checkout.service import checkout as checkout_service -from polar.checkout.tax import calculate_tax from polar.integrations.stripe.service import StripeService +from polar.kit.tax import calculate_tax from polar.models import Checkout, Product, UserOrganization from polar.postgres import AsyncSession from tests.fixtures.auth import AuthSubjectFixture diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index 1d0644ad02..8d0dccb74f 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -31,13 +31,13 @@ PaymentRequired, ) from polar.checkout.service import checkout as checkout_service -from polar.checkout.tax import IncompleteTaxLocation, TaxIDFormat, calculate_tax from polar.discount.service import discount as discount_service from polar.enums import PaymentProcessor from polar.exceptions import PolarRequestValidationError from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService from polar.kit.address import Address +from polar.kit.tax import IncompleteTaxLocation, TaxIDFormat, calculate_tax from polar.kit.utils import utc_now from polar.locker import Locker from polar.models import ( diff --git a/server/tests/checkout/test_tax.py b/server/tests/checkout/test_tax.py index 5d56ed7916..90512cd080 100644 --- a/server/tests/checkout/test_tax.py +++ b/server/tests/checkout/test_tax.py @@ -1,7 +1,7 @@ import pytest from pydantic_extra_types.country import CountryAlpha2 -from polar.checkout.tax import TaxID, TaxIDFormat, validate_tax_id +from polar.kit.tax import TaxID, TaxIDFormat, validate_tax_id @pytest.mark.parametrize( From f104f682a178f5850e577729facd54c45de36a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 14:02:17 +0100 Subject: [PATCH 02/47] server/customer: implement CRUD API --- server/polar/auth/scope.py | 5 + server/polar/customer/auth.py | 24 +++++ server/polar/customer/endpoints.py | 144 +++++++++++++++++++++++++++++ server/polar/customer/schemas.py | 36 ++++++++ server/polar/customer/service.py | 119 +++++++++++++++++++++++- server/polar/customer/sorting.py | 18 ++++ 6 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 server/polar/customer/auth.py create mode 100644 server/polar/customer/endpoints.py create mode 100644 server/polar/customer/schemas.py create mode 100644 server/polar/customer/sorting.py diff --git a/server/polar/auth/scope.py b/server/polar/auth/scope.py index 19a4efd5bc..aa756639a3 100644 --- a/server/polar/auth/scope.py +++ b/server/polar/auth/scope.py @@ -41,6 +41,9 @@ class Scope(StrEnum): subscriptions_read = "subscriptions:read" subscriptions_write = "subscriptions:write" + customers_read = "customers:read" + customers_write = "customers:write" + orders_read = "orders:read" metrics_read = "metrics:read" @@ -106,6 +109,8 @@ def __get_pydantic_json_schema__( Scope.subscriptions_write: ( "Create or modify subscriptions made on your organizations" ), + Scope.customers_read: "Read customers", + Scope.customers_write: "Create or modify customers", Scope.orders_read: "Read orders made on your organizations", Scope.metrics_read: "Read metrics", Scope.webhooks_read: "Read webhooks", diff --git a/server/polar/customer/auth.py b/server/polar/customer/auth.py new file mode 100644 index 0000000000..d174a47259 --- /dev/null +++ b/server/polar/customer/auth.py @@ -0,0 +1,24 @@ +from typing import Annotated + +from fastapi import Depends + +from polar.auth.dependencies import Authenticator +from polar.auth.models import AuthSubject, User +from polar.auth.scope import Scope +from polar.models.organization import Organization + +_CustomerRead = Authenticator( + required_scopes={ + Scope.web_default, + Scope.customers_read, + Scope.customers_write, + }, + allowed_subjects={User, Organization}, +) +CustomerRead = Annotated[AuthSubject[User | Organization], Depends(_CustomerRead)] + +_CustomerWrite = Authenticator( + required_scopes={Scope.web_default, Scope.customers_write}, + allowed_subjects={User, Organization}, +) +CustomerWrite = Annotated[AuthSubject[User | Organization], Depends(_CustomerWrite)] diff --git a/server/polar/customer/endpoints.py b/server/polar/customer/endpoints.py new file mode 100644 index 0000000000..b10398bde0 --- /dev/null +++ b/server/polar/customer/endpoints.py @@ -0,0 +1,144 @@ +from typing import Annotated + +from fastapi import Depends, Path, Query +from pydantic import UUID4 + +from polar.authz.service import Authz +from polar.exceptions import ResourceNotFound +from polar.kit.pagination import ListResource, PaginationParamsQuery +from polar.kit.schemas import MultipleQueryFilter +from polar.models import Customer +from polar.openapi import APITag +from polar.organization.schemas import OrganizationID +from polar.postgres import AsyncSession, get_db_session +from polar.routing import APIRouter + +from . import auth, sorting +from .schemas import Customer as CustomerSchema +from .schemas import CustomerCreate, CustomerUpdate +from .service import customer as customer_service + +router = APIRouter( + prefix="/customers", tags=["customers", APITag.documented, APITag.featured] +) + + +CustomerID = Annotated[UUID4, Path(description="The customer ID.")] +CustomerNotFound = { + "description": "Customer not found.", + "model": ResourceNotFound.schema(), +} + + +@router.get("/", summary="List Customers", response_model=ListResource[CustomerSchema]) +async def list( + auth_subject: auth.CustomerRead, + pagination: PaginationParamsQuery, + sorting: sorting.ListSorting, + organization_id: MultipleQueryFilter[OrganizationID] | None = Query( + None, title="OrganizationID Filter", description="Filter by organization ID." + ), + session: AsyncSession = Depends(get_db_session), +) -> ListResource[CustomerSchema]: + """List customers.""" + results, count = await customer_service.list( + session, + auth_subject, + organization_id=organization_id, + pagination=pagination, + sorting=sorting, + ) + + return ListResource.from_paginated_results( + [CustomerSchema.model_validate(result) for result in results], + count, + pagination, + ) + + +@router.get( + "/{id}", + summary="Get Customer", + response_model=CustomerSchema, + responses={404: CustomerNotFound}, +) +async def get( + id: CustomerID, + auth_subject: auth.CustomerRead, + session: AsyncSession = Depends(get_db_session), +) -> Customer: + """Get a customer by ID.""" + customer = await customer_service.get_by_id(session, auth_subject, id) + + if customer is None: + raise ResourceNotFound() + + return customer + + +@router.post( + "/", + response_model=CustomerSchema, + status_code=201, + summary="Create Customer", + responses={201: {"description": "Customer created."}}, +) +async def create( + customer_create: CustomerCreate, + auth_subject: auth.CustomerWrite, + session: AsyncSession = Depends(get_db_session), + authz: Authz = Depends(Authz.authz), +) -> Customer: + """Create a customer.""" + return await customer_service.create(session, authz, customer_create, auth_subject) + + +@router.patch( + "/{id}", + response_model=CustomerSchema, + summary="Update Customer", + responses={ + 200: {"description": "Customer updated."}, + 404: CustomerNotFound, + }, +) +async def update( + id: CustomerID, + customer_update: CustomerUpdate, + auth_subject: auth.CustomerWrite, + session: AsyncSession = Depends(get_db_session), +) -> Customer: + """Update a customer.""" + customer = await customer_service.get_by_id(session, auth_subject, id) + + if customer is None: + raise ResourceNotFound() + + return await customer_service.update(session, customer, customer_update) + + +@router.delete( + "/{id}", + status_code=204, + summary="Delete Customer", + responses={ + 204: {"description": "Customer deleted."}, + 404: CustomerNotFound, + }, +) +async def delete( + id: CustomerID, + auth_subject: auth.CustomerWrite, + session: AsyncSession = Depends(get_db_session), +) -> None: + """ + Delete a customer. + + Immediately cancels any active subscriptions and revokes any active benefits. + """ + customer = await customer_service.get_by_id(session, auth_subject, id) + + if customer is None: + raise ResourceNotFound() + + await customer_service.delete(session, customer) diff --git a/server/polar/customer/schemas.py b/server/polar/customer/schemas.py new file mode 100644 index 0000000000..66080d8e2b --- /dev/null +++ b/server/polar/customer/schemas.py @@ -0,0 +1,36 @@ +from pydantic import UUID4, Field + +from polar.kit.address import Address +from polar.kit.schemas import EmailStrDNS, IDSchema, Schema, TimestampedSchema +from polar.kit.tax import TaxID +from polar.organization.schemas import OrganizationID + + +class CustomerCreate(Schema): + email: EmailStrDNS + name: str | None = None + billing_address: Address | None = None + tax_id: TaxID | None = None + organization_id: OrganizationID | None = Field( + default=None, + description=( + "The ID of the organization owning the customer. " + "**Required unless you use an organization token.**" + ), + ) + + +class CustomerUpdate(Schema): + email: EmailStrDNS | None = None + name: str | None = None + billing_address: Address | None = None + tax_id: TaxID | None = None + + +class Customer(IDSchema, TimestampedSchema): + email: str + email_verified: bool + name: str | None + billing_address: Address | None + tax_id: TaxID | None + organization_id: UUID4 diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 4c2ba415cd..8057cdea30 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -1,12 +1,105 @@ -from sqlalchemy import select +import uuid +from collections.abc import Sequence +from typing import Any + +from sqlalchemy import Select, UnaryExpression, asc, desc, select from stripe import Customer as StripeCustomer +from polar.auth.models import AuthSubject, is_organization, is_user +from polar.authz.service import AccessType, Authz +from polar.exceptions import NotPermitted +from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader -from polar.models import Customer, Organization +from polar.kit.sorting import Sorting +from polar.models import Customer, Organization, User, UserOrganization +from polar.organization.resolver import get_payload_organization from polar.postgres import AsyncSession +from .schemas import CustomerCreate, CustomerUpdate +from .sorting import CustomerSortProperty + class CustomerService(ResourceServiceReader[Customer]): + async def list( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + *, + organization_id: Sequence[uuid.UUID] | None = None, + pagination: PaginationParams, + sorting: list[Sorting[CustomerSortProperty]] = [ + (CustomerSortProperty.created_at, True) + ], + ) -> tuple[Sequence[Customer], int]: + statement = self._get_readable_customer_statement(auth_subject) + + if organization_id is not None: + statement = statement.where(Customer.organization_id.in_(organization_id)) + + order_by_clauses: list[UnaryExpression[Any]] = [] + for criterion, is_desc in sorting: + clause_function = desc if is_desc else asc + if criterion == CustomerSortProperty.created_at: + order_by_clauses.append(clause_function(Customer.created_at)) + elif criterion == CustomerSortProperty.email: + order_by_clauses.append(clause_function(Customer.email)) + elif criterion == CustomerSortProperty.name: + order_by_clauses.append(clause_function(Customer.name)) + statement = statement.order_by(*order_by_clauses) + + return await paginate(session, statement, pagination=pagination) + + async def get_by_id( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + id: uuid.UUID, + ) -> Customer | None: + statement = self._get_readable_customer_statement(auth_subject).where( + Customer.id == id + ) + result = await session.execute(statement) + return result.unique().scalar_one_or_none() + + async def create( + self, + session: AsyncSession, + authz: Authz, + customer_create: CustomerCreate, + auth_subject: AuthSubject[User | Organization], + ) -> Customer: + subject = auth_subject.subject + + organization = await get_payload_organization( + session, auth_subject, customer_create + ) + if not await authz.can(subject, AccessType.write, organization): + raise NotPermitted() + + customer = Customer( + organization=organization, + **customer_create.model_dump(exclude={"organization_id"}), + ) + + session.add(customer) + return customer + + async def update( + self, session: AsyncSession, customer: Customer, customer_update: CustomerUpdate + ) -> Customer: + for attr, value in customer_update.model_dump(exclude_unset=True).items(): + setattr(customer, attr, value) + + session.add(customer) + return customer + + async def delete(self, session: AsyncSession, customer: Customer) -> Customer: + # TODO: cancel subscriptions, revoke benefits, etc. + + customer.set_deleted_at() + session.add(customer) + return customer + async def get_by_id_and_organization( self, session: AsyncSession, id: str, organization: Organization ) -> Customer | None: @@ -47,5 +140,27 @@ async def create_from_stripe_customer( session.add(customer) return customer + def _get_readable_customer_statement( + self, auth_subject: AuthSubject[User | Organization] + ) -> Select[tuple[Customer]]: + statement = select(Customer).where(Customer.deleted_at.is_(None)) + + if is_user(auth_subject): + user = auth_subject.subject + statement = statement.where( + Customer.organization_id.in_( + select(UserOrganization.organization_id).where( + UserOrganization.user_id == user.id, + UserOrganization.deleted_at.is_(None), + ) + ) + ) + elif is_organization(auth_subject): + statement = statement.where( + Customer.organization_id == auth_subject.subject.id, + ) + + return statement + customer = CustomerService(Customer) diff --git a/server/polar/customer/sorting.py b/server/polar/customer/sorting.py new file mode 100644 index 0000000000..b49d11b728 --- /dev/null +++ b/server/polar/customer/sorting.py @@ -0,0 +1,18 @@ +from enum import StrEnum +from typing import Annotated + +from fastapi import Depends + +from polar.kit.sorting import Sorting, SortingGetter + + +class CustomerSortProperty(StrEnum): + created_at = "created_at" + email = "email" + customer_name = "name" # `name` is a reserved word, so we use `customer_name` + + +ListSorting = Annotated[ + list[Sorting[CustomerSortProperty]], + Depends(SortingGetter(CustomerSortProperty, ["-created_at"])), +] From 6af56cda74c5c13165a88f55677796714d94f8c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 14:13:23 +0100 Subject: [PATCH 03/47] server: fix order and subscription API to work with customers --- server/polar/order/endpoints.py | 6 +++--- server/polar/subscription/endpoints.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/polar/order/endpoints.py b/server/polar/order/endpoints.py index 8053cddeb2..5a6bcad4ae 100644 --- a/server/polar/order/endpoints.py +++ b/server/polar/order/endpoints.py @@ -50,8 +50,8 @@ async def list( discount_id: MultipleQueryFilter[UUID4] | None = Query( None, title="DiscountID Filter", description="Filter by discount ID." ), - user_id: MultipleQueryFilter[UUID4] | None = Query( - None, title="UserID Filter", description="Filter by customer's user ID." + customer_id: MultipleQueryFilter[UUID4] | None = Query( + None, title="CustomerID Filter", description="Filter by customer ID." ), session: AsyncSession = Depends(get_db_session), ) -> ListResource[OrderSchema]: @@ -63,7 +63,7 @@ async def list( product_id=product_id, product_price_type=product_price_type, discount_id=discount_id, - user_id=user_id, + customer_id=customer_id, pagination=pagination, sorting=sorting, ) diff --git a/server/polar/subscription/endpoints.py b/server/polar/subscription/endpoints.py index 2e7b97eec5..a61048c364 100644 --- a/server/polar/subscription/endpoints.py +++ b/server/polar/subscription/endpoints.py @@ -108,7 +108,7 @@ async def create_csv() -> AsyncGenerator[str, None]: for sub in subscribers: yield csv_writer.getrow( ( - sub.user.email, + sub.customer.email, sub.created_at.isoformat(), "true" if sub.active else "false", sub.product.name, From 9937a923a7283e21be96f254d220fa03af32c562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 14:42:17 +0100 Subject: [PATCH 04/47] server/checkout: fix logic and tests related to customers changes --- server/polar/checkout/endpoints.py | 2 +- server/polar/checkout/service.py | 17 ++--- server/tests/checkout/test_service.py | 94 +++++++------------------ server/tests/fixtures/random_objects.py | 45 +++++++++--- 4 files changed, 71 insertions(+), 87 deletions(-) diff --git a/server/polar/checkout/endpoints.py b/server/polar/checkout/endpoints.py index e691a69647..dc3982eaf5 100644 --- a/server/polar/checkout/endpoints.py +++ b/server/polar/checkout/endpoints.py @@ -180,7 +180,7 @@ async def client_create( """Create a checkout session from a client. Suitable to build checkout links.""" ip_address = request.client.host if request.client else None return await checkout_service.client_create( - session, checkout_create, auth_subject, ip_geolocation_client, ip_address + session, checkout_create, ip_geolocation_client, ip_address ) diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index ab92f740b0..71e4ee83eb 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import contains_eager, joinedload, selectinload from polar.auth.models import ( - Anonymous, AuthSubject, is_organization, is_user, @@ -211,8 +210,12 @@ async def get_by_id( auth_subject: AuthSubject[User | Organization], id: uuid.UUID, ) -> Checkout | None: - statement = self._get_readable_checkout_statement(auth_subject).where( - Checkout.id == id + statement = ( + self._get_readable_checkout_statement(auth_subject) + .where(Checkout.id == id) + .options( + joinedload(Checkout.customer), + ) ) result = await session.execute(statement) return result.unique().scalar_one_or_none() @@ -375,7 +378,6 @@ async def client_create( self, session: AsyncSession, checkout_create: CheckoutCreatePublic, - auth_subject: AuthSubject[User | Anonymous], ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None, ip_address: str | None = None, ) -> Checkout: @@ -708,8 +710,6 @@ async def _confirm_inner( if len(errors) > 0: raise PolarRequestValidationError(errors) - assert checkout.customer_email is not None - if checkout.payment_processor == PaymentProcessor.stripe: customer = await self._create_or_update_customer(session, checkout) checkout.customer = customer @@ -1034,11 +1034,12 @@ async def get_by_client_secret( ) .join(Checkout.product) .options( + joinedload(Checkout.customer), contains_eager(Checkout.product).options( joinedload(Product.organization), selectinload(Product.product_medias), selectinload(Product.attached_custom_fields), - ) + ), ) ) result = await session.execute(statement) @@ -1528,7 +1529,6 @@ async def _create_or_update_customer( ) stripe_customer_id = customer.stripe_customer_id - if stripe_customer_id is None: create_params: stripe_lib.Customer.CreateParams = {"email": customer.email} if checkout.customer_name is not None: @@ -1554,6 +1554,7 @@ async def _create_or_update_customer( else None, **update_params, ) + customer.stripe_customer_id = stripe_customer_id session.add(customer) await session.flush() diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index 8d0dccb74f..ef04a46f6c 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -11,7 +11,7 @@ from pytest_mock import MockerFixture from sqlalchemy.orm import joinedload -from polar.auth.models import Anonymous, AuthMethod, AuthSubject +from polar.auth.models import Anonymous, AuthSubject from polar.checkout.schemas import ( CheckoutConfirmStripe, CheckoutCreatePublic, @@ -42,6 +42,7 @@ from polar.locker import Locker from polar.models import ( Checkout, + Customer, Discount, Organization, Product, @@ -63,10 +64,10 @@ create_checkout, create_checkout_link, create_custom_field, + create_customer, create_product, create_product_price_fixed, create_subscription, - create_user, ) @@ -149,10 +150,10 @@ async def checkout_confirmed_recurring_upgrade( save_fixture: SaveFixture, product: Product, product_recurring_free_price: Product, - user_second: User, + customer: Customer, ) -> Checkout: subscription = await create_subscription( - save_fixture, product=product_recurring_free_price, user=user_second + save_fixture, product=product_recurring_free_price, customer=customer ) return await create_checkout( save_fixture, @@ -466,10 +467,10 @@ async def test_invalid_paid_subscription( auth_subject: AuthSubject[User | Organization], user_organization: UserOrganization, product: Product, - user_second: User, + customer: Customer, ) -> None: subscription = await create_subscription( - save_fixture, product=product, user=user_second + save_fixture, product=product, customer=customer ) price = product.prices[0] @@ -738,10 +739,10 @@ async def test_valid_subscription_upgrade( user_organization: UserOrganization, product: Product, product_recurring_free_price: Product, - user_second: User, + customer: Customer, ) -> None: subscription = await create_subscription( - save_fixture, product=product_recurring_free_price, user=user_second + save_fixture, product=product_recurring_free_price, customer=customer ) price = product.prices[0] @@ -990,7 +991,6 @@ async def test_not_existing_price( CheckoutCreatePublic( product_price_id=uuid.uuid4(), ), - auth_subject, ) async def test_archived_price( @@ -1008,9 +1008,7 @@ async def test_archived_price( ) with pytest.raises(PolarRequestValidationError): await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, + session, CheckoutCreatePublic(product_price_id=price.id) ) async def test_archived_product( @@ -1028,7 +1026,6 @@ async def test_archived_product( CheckoutCreatePublic( product_price_id=product_one_time.prices[0].id, ), - auth_subject, ) async def test_valid_fixed_price( @@ -1040,9 +1037,7 @@ async def test_valid_fixed_price( price = product_one_time.prices[0] assert isinstance(price, ProductPriceFixed) checkout = await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, + session, CheckoutCreatePublic(product_price_id=price.id) ) assert checkout.product_price == price @@ -1059,9 +1054,7 @@ async def test_valid_free_price( price = product_one_time_free_price.prices[0] assert isinstance(price, ProductPriceFree) checkout = await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, + session, CheckoutCreatePublic(product_price_id=price.id) ) assert checkout.product_price == price @@ -1080,9 +1073,7 @@ async def test_valid_custom_price( price.preset_amount = 4242 checkout = await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, + session, CheckoutCreatePublic(product_price_id=price.id) ) assert checkout.product_price == price @@ -1090,45 +1081,6 @@ async def test_valid_custom_price( assert checkout.amount == price.preset_amount assert checkout.currency == price.price_currency - @pytest.mark.auth( - AuthSubjectFixture(subject="user", method=AuthMethod.COOKIE), - AuthSubjectFixture(subject="user", method=AuthMethod.OAUTH2_ACCESS_TOKEN), - ) - async def test_valid_direct_user( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - product_one_time: Product, - ) -> None: - price = product_one_time.prices[0] - assert isinstance(price, ProductPriceFixed) - checkout = await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, - ) - assert checkout.customer == auth_subject.subject - assert checkout.customer_email == auth_subject.subject.email - - @pytest.mark.auth( - AuthSubjectFixture(subject="user", method=AuthMethod.PERSONAL_ACCESS_TOKEN), - ) - async def test_valid_indirect_user( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - product_one_time: Product, - ) -> None: - price = product_one_time.prices[0] - assert isinstance(price, ProductPriceFixed) - checkout = await checkout_service.client_create( - session, - CheckoutCreatePublic(product_price_id=price.id), - auth_subject, - ) - - assert checkout.customer is None - async def test_valid_from_legacy_checkout_link( self, session: AsyncSession, @@ -1142,7 +1094,6 @@ async def test_valid_from_legacy_checkout_link( CheckoutCreatePublic( product_price_id=price.id, from_legacy_checkout_link=True ), - auth_subject, ) assert checkout.product_price == price @@ -1569,11 +1520,11 @@ async def test_ignore_email_update_if_customer_set( self, session: AsyncSession, save_fixture: SaveFixture, - user: User, + customer: Customer, checkout_one_time_fixed: Checkout, ) -> None: - checkout_one_time_fixed.customer = user - checkout_one_time_fixed.customer_email = user.email + checkout_one_time_fixed.customer = customer + checkout_one_time_fixed.customer_email = customer.email await save_fixture(checkout_one_time_fixed) checkout = await checkout_service.update( @@ -1582,7 +1533,7 @@ async def test_ignore_email_update_if_customer_set( CheckoutUpdate(customer_email="updatedemail@example.com"), ) - assert checkout.customer_email == user.email + assert checkout.customer_email == customer.email async def test_valid_metadata( self, @@ -1997,11 +1948,16 @@ async def test_valid_stripe_existing_customer( stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, + organization: Organization, checkout_one_time_fixed: Checkout, ) -> None: - user = await create_user(save_fixture, stripe_customer_id="STRIPE_CUSTOMER_ID") - checkout_one_time_fixed.customer = user - checkout_one_time_fixed.customer_email = user.email + customer = await create_customer( + save_fixture, + organization=organization, + stripe_customer_id="CHECKOUT_CUSTOMER_ID", + ) + checkout_one_time_fixed.customer = customer + checkout_one_time_fixed.customer_email = customer.email await save_fixture(checkout_one_time_fixed) stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 77c188e87f..205f67b23b 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -22,6 +22,7 @@ Benefit, Checkout, CheckoutLink, + Customer, CustomField, Discount, DiscountProduct, @@ -832,11 +833,29 @@ async def discount_percentage_100( ) +async def create_customer( + save_fixture: SaveFixture, + *, + organization: Organization, + email: str = "customer@example.com", + email_verified: bool = False, + stripe_customer_id: str = "STRIPE_CUSTOMER_ID", +) -> Customer: + customer = Customer( + email=email, + email_verified=email_verified, + stripe_customer_id=stripe_customer_id, + organization=organization, + ) + await save_fixture(customer) + return customer + + async def create_order( save_fixture: SaveFixture, *, product: Product, - user: User, + customer: Customer, product_price: ProductPrice | None = None, subscription: Subscription | None = None, amount: int = 1000, @@ -853,7 +872,7 @@ async def create_order( currency="usd", billing_reason=billing_reason, stripe_invoice_id=stripe_invoice_id, - user=user, + customer=customer, product=product, product_price=product_price if product_price is not None @@ -910,7 +929,7 @@ async def create_subscription( *, product: Product, price: ProductPrice | None = None, - user: User, + customer: Customer, status: SubscriptionStatus = SubscriptionStatus.incomplete, started_at: datetime | None = None, ended_at: datetime | None = None, @@ -940,7 +959,7 @@ async def create_subscription( cancel_at_period_end=False, started_at=started_at, ended_at=ended_at, - user=user, + customer=customer, product=product, price=price, discount=discount, @@ -954,7 +973,7 @@ async def create_active_subscription( *, product: Product, price: ProductPrice | None = None, - user: User, + customer: Customer, organization: Organization | None = None, started_at: datetime | None = None, ended_at: datetime | None = None, @@ -964,7 +983,7 @@ async def create_active_subscription( save_fixture, product=product, price=price, - user=user, + customer=customer, status=SubscriptionStatus.active, started_at=started_at or utc_now(), ended_at=ended_at, @@ -1113,7 +1132,7 @@ async def create_checkout( amount: int | None = None, tax_amount: int | None = None, currency: str | None = None, - customer: User | None = None, + customer: Customer | None = None, subscription: Subscription | None = None, discount: Discount | None = None, ) -> Checkout: @@ -1245,13 +1264,21 @@ async def organization_second_members( return users +@pytest_asyncio.fixture +async def customer( + save_fixture: SaveFixture, + organization: Organization, +) -> Customer: + return await create_customer(save_fixture, organization=organization) + + @pytest_asyncio.fixture async def subscription( save_fixture: SaveFixture, product: Product, - user: User, + customer: Customer, ) -> Subscription: - return await create_subscription(save_fixture, product=product, user=user) + return await create_subscription(save_fixture, product=product, customer=customer) async def create_benefit_grant( From 3fb5129fb047f6e7bf825f1034f1512fc72890c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 15:21:57 +0100 Subject: [PATCH 05/47] server/order: fix logic and tests related to customers changes --- server/polar/order/schemas.py | 26 +++-- server/polar/order/service.py | 2 +- server/polar/subscription/schemas.py | 22 ++-- server/polar/transaction/service/payment.py | 12 +-- server/polar/webhook/webhooks.py | 4 +- server/tests/order/test_endpoints.py | 6 +- server/tests/order/test_service.py | 109 ++++++++------------ server/tests/transaction/conftest.py | 19 ++-- 8 files changed, 97 insertions(+), 103 deletions(-) diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index 28e41c44ea..e2bebab5ad 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -4,12 +4,11 @@ from pydantic import UUID4, Field from polar.custom_field.data import CustomFieldDataOutputMixin -from polar.discount.schemas import ( - DiscountMinimal, -) +from polar.discount.schemas import DiscountMinimal from polar.kit.address import Address from polar.kit.metadata import MetadataOutputMixin from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema +from polar.kit.tax import TaxID from polar.models.order import OrderBillingReason from polar.product.schemas import ProductBase, ProductPrice from polar.subscription.schemas import SubscriptionBase @@ -24,7 +23,10 @@ class OrderBase( billing_reason: OrderBillingReason billing_address: Address | None - user_id: UUID4 + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) + customer_id: UUID4 product_id: UUID4 product_price_id: UUID4 discount_id: UUID4 | None @@ -39,12 +41,13 @@ def get_amount_display(self) -> str: )}" -class OrderUser(Schema): - id: UUID4 +class OrderCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): email: str - public_name: str - github_username: str | None - avatar_url: str | None + email_verified: bool + name: str | None + billing_address: Address | None + tax_id: TaxID | None + organization_id: UUID4 class OrderProduct(ProductBase): ... @@ -57,7 +60,10 @@ class OrderSubscription(SubscriptionBase, MetadataOutputMixin): ... class Order(OrderBase): - user: OrderUser + customer: OrderCustomer + user: OrderCustomer = Field( + validation_alias="customer", deprecated="Use `customer`." + ) product: OrderProduct product_price: ProductPrice discount: OrderDiscount | None diff --git a/server/polar/order/service.py b/server/polar/order/service.py index ff1ffdb34b..47cf0862d7 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -414,8 +414,8 @@ async def create_order_from_stripe( customer = await customer_service.get_by_stripe_customer_id( session, stripe_customer_id ) - stripe_customer = await stripe_service.get_customer(stripe_customer_id) if customer is None: + stripe_customer = await stripe_service.get_customer(stripe_customer_id) customer = await customer_service.create_from_stripe_customer( session, stripe_customer, product.organization ) diff --git a/server/polar/subscription/schemas.py b/server/polar/subscription/schemas.py index 33724b4691..c4866fa36b 100644 --- a/server/polar/subscription/schemas.py +++ b/server/polar/subscription/schemas.py @@ -7,6 +7,7 @@ from polar.custom_field.data import CustomFieldDataOutputMixin from polar.discount.schemas import DiscountMinimal from polar.enums import SubscriptionRecurringInterval +from polar.kit.address import Address from polar.kit.metadata import MetadataOutputMixin from polar.kit.schemas import ( EmailStrDNS, @@ -15,15 +16,18 @@ Schema, TimestampedSchema, ) +from polar.kit.tax import TaxID from polar.models.subscription import SubscriptionStatus from polar.product.schemas import Product, ProductPriceRecurring -class SubscriptionUser(Schema): +class SubscriptionCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): email: str - public_name: str - github_username: str | None - avatar_url: str | None + email_verified: bool + name: str | None + billing_address: Address | None + tax_id: TaxID | None + organization_id: UUID4 class SubscriptionBase(IDSchema, TimestampedSchema): @@ -37,7 +41,10 @@ class SubscriptionBase(IDSchema, TimestampedSchema): started_at: datetime | None ended_at: datetime | None - user_id: UUID4 + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) + customer_id: UUID4 product_id: UUID4 price_id: UUID4 discount_id: UUID4 | None @@ -59,7 +66,10 @@ def get_amount_display(self) -> str: class Subscription(CustomFieldDataOutputMixin, MetadataOutputMixin, SubscriptionBase): - user: SubscriptionUser + customer: SubscriptionCustomer + user: SubscriptionCustomer = Field( + validation_alias="customer", deprecated="Use `customer`." + ) product: Product price: ProductPriceRecurring discount: SubscriptionDiscount | None diff --git a/server/polar/transaction/service/payment.py b/server/polar/transaction/service/payment.py index 66a53d02e2..746541c414 100644 --- a/server/polar/transaction/service/payment.py +++ b/server/polar/transaction/service/payment.py @@ -3,6 +3,7 @@ import stripe as stripe_lib from sqlalchemy import select +from polar.customer.service import customer as customer_service from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import stripe as stripe_service from polar.integrations.stripe.utils import get_expandable_id @@ -11,7 +12,6 @@ from polar.organization.service import organization as organization_service from polar.pledge.service import pledge as pledge_service from polar.postgres import AsyncSession -from polar.user.service.user import user as user_service from .base import BaseTransactionService, BaseTransactionServiceError from .processor_fee import ( @@ -46,11 +46,11 @@ async def create_payment( # Retrieve customer customer_id = None - payment_user = None + payment_customer = None payment_organization = None if charge.customer: customer_id = get_expandable_id(charge.customer) - payment_user = await user_service.get_by_stripe_customer_id( + payment_customer = await customer_service.get_by_stripe_customer_id( session, customer_id ) payment_organization = await organization_service.get_by( @@ -95,9 +95,9 @@ async def create_payment( raise PledgeDoesNotExist(charge.id, payment_intent) # If we were not able to link to a payer by Stripe Customer ID, # link from the pledge data. Happens for anonymous pledges. - if payment_user is None and payment_organization is None: + if payment_customer is None and payment_organization is None: await session.refresh(pledge, {"user", "by_organization"}) - payment_user = pledge.user + payment_customer = None # TODO: Pledge customers? payment_organization = pledge.by_organization risk = getattr(charge, "outcome", {}) @@ -112,7 +112,7 @@ async def create_payment( tax_country=tax_country, tax_state=tax_state, customer_id=customer_id, - payment_user=payment_user, + payment_customer=payment_customer, payment_organization=payment_organization, charge_id=charge.id, pledge=pledge, diff --git a/server/polar/webhook/webhooks.py b/server/polar/webhook/webhooks.py index b36f139989..34dd8b2f53 100644 --- a/server/polar/webhook/webhooks.py +++ b/server/polar/webhook/webhooks.py @@ -193,7 +193,7 @@ def get_discord_payload(self, target: User | Organization) -> str: fields: list[DiscordEmbedField] = [ {"name": "Product", "value": self.data.product.name}, {"name": "Amount", "value": amount_display}, - {"name": "Customer", "value": self.data.user.email}, + {"name": "Customer", "value": self.data.customer.email}, ] if self.data.subscription is not None: fields.append({"name": "Subscription", "value": "Yes"}) @@ -222,7 +222,7 @@ def get_slack_payload(self, target: User | Organization) -> str: fields: list[SlackText] = [ {"type": "mrkdwn", "text": f"*Product*\n{self.data.product.name}"}, {"type": "mrkdwn", "text": f"*Amount*\n{amount_display}"}, - {"type": "mrkdwn", "text": f"*Customer*\n{self.data.user.email}"}, + {"type": "mrkdwn", "text": f"*Customer*\n{self.data.customer.email}"}, ] if self.data.subscription is not None: fields.append({"type": "mrkdwn", "text": "*Subscription*\nYes"}) diff --git a/server/tests/order/test_endpoints.py b/server/tests/order/test_endpoints.py index 0d28ec42f4..8772ee8b40 100644 --- a/server/tests/order/test_endpoints.py +++ b/server/tests/order/test_endpoints.py @@ -5,7 +5,7 @@ from httpx import AsyncClient from polar.auth.scope import Scope -from polar.models import Order, Product, User, UserOrganization +from polar.models import Customer, Order, Product, UserOrganization from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_order @@ -13,9 +13,9 @@ @pytest_asyncio.fixture async def orders( - save_fixture: SaveFixture, product: Product, user_second: User + save_fixture: SaveFixture, product: Product, customer: Customer ) -> list[Order]: - return [await create_order(save_fixture, product=product, user=user_second)] + return [await create_order(save_fixture, product=product, customer=customer)] @pytest.mark.asyncio diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index 7c6180decd..3ff4de32d8 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -16,6 +16,7 @@ from polar.kit.pagination import PaginationParams from polar.models import ( Account, + Customer, Discount, Product, Subscription, @@ -55,7 +56,7 @@ def construct_stripe_invoice( charge_id: str | None = "CHARGE_ID", subscription_id: str | None = "SUBSCRIPTION_ID", subscription_details: dict[str, Any] | None = None, - customer_id: str = "CUSTOMER_ID", + customer_id: str = "STRIPE_CUSTOMER_ID", lines: list[tuple[str, bool, dict[str, str] | None]] = [("PRICE_ID", False, None)], metadata: dict[str, str] = {}, billing_reason: str = "subscription_create", @@ -124,9 +125,9 @@ async def test_user_not_organization_member( save_fixture: SaveFixture, session: AsyncSession, product: Product, - user_second: User, + customer: Customer, ) -> None: - await create_order(save_fixture, product=product, user=user_second) + await create_order(save_fixture, product=product, customer=customer) orders, count = await order_service.list( session, auth_subject, pagination=PaginationParams(1, 10) @@ -144,18 +145,18 @@ async def test_user_organization_member( session: AsyncSession, product: Product, product_organization_second: Product, - user_second: User, + customer: Customer, ) -> None: order = await create_order( save_fixture, product=product, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_1", ) await create_order( save_fixture, product=product_organization_second, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_2", ) @@ -178,7 +179,7 @@ async def test_user_organization_filter( session: AsyncSession, product: Product, product_organization_second: Product, - user_second: User, + customer: Customer, ) -> None: user_organization_second_admin = UserOrganization( user_id=user.id, organization_id=organization_second.id @@ -188,13 +189,13 @@ async def test_user_organization_filter( order_organization = await create_order( save_fixture, product=product, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_1", ) order_organization_second = await create_order( save_fixture, product=product_organization_second, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_2", ) @@ -227,18 +228,18 @@ async def test_organization( session: AsyncSession, product: Product, product_organization_second: Product, - user_second: User, + customer: Customer, ) -> None: order = await create_order( save_fixture, product=product, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_1", ) await create_order( save_fixture, product=product_organization_second, - user=user_second, + customer=customer, stripe_invoice_id="INVOICE_2", ) @@ -323,11 +324,11 @@ async def test_subscription_no_account( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice.total - (invoice.tax or 0) - assert order.user.id == subscription.user_id + assert order.customer == subscription.customer assert order.product == product assert order.product_price == product.prices[0] assert order.subscription == subscription - assert order.user.stripe_customer_id == invoice.customer + assert order.customer.stripe_customer_id == invoice.customer assert order.billing_reason == invoice.billing_reason assert order.created_at == created_datetime @@ -377,11 +378,11 @@ async def test_subscription_proration( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice.total - (invoice.tax or 0) - assert order.user.id == subscription.user_id + assert order.customer == subscription.customer assert order.product == product assert order.product_price == product.prices[0] assert order.subscription == subscription - assert order.user.stripe_customer_id == invoice.customer + assert order.customer.stripe_customer_id == invoice.customer assert order.billing_reason == invoice.billing_reason assert order.created_at == created_datetime @@ -416,11 +417,11 @@ async def test_subscription_only_proration( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice.total - (invoice.tax or 0) - assert order.user.id == subscription.user_id + assert order.customer == subscription.customer assert order.product == product assert order.product_price == product.prices[0] assert order.subscription == subscription - assert order.user.stripe_customer_id == invoice.customer + assert order.customer.stripe_customer_id == invoice.customer assert order.billing_reason == invoice.billing_reason assert order.created_at == created_datetime @@ -470,11 +471,11 @@ async def test_subscription_with_account( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice_total - assert order.user.id == subscription.user_id + assert order.customer == subscription.customer assert order.product == product assert order.product_price == product.prices[0] assert order.subscription == subscription - assert order.user.stripe_customer_id == invoice.customer + assert order.customer.stripe_customer_id == invoice.customer assert order.billing_reason == invoice.billing_reason assert order.created_at == created_datetime @@ -530,11 +531,11 @@ async def test_subscription_applied_balance( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice.total - (invoice.tax or 0) - assert order.user.id == subscription.user_id + assert order.customer == subscription.customer assert order.product == product assert order.product_price == product.prices[0] assert order.subscription == subscription - assert order.user.stripe_customer_id == invoice.customer + assert order.customer.stripe_customer_id == invoice.customer assert order.billing_reason == invoice.billing_reason assert order.created_at == created_datetime @@ -600,7 +601,7 @@ async def test_one_time_product( save_fixture: SaveFixture, product_one_time: Product, organization_account: Account, - user: User, + customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: created_datetime, created_unix_timestamp = event_creation_time @@ -613,9 +614,6 @@ async def test_one_time_product( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -642,7 +640,7 @@ async def test_one_time_product( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice_total - assert order.user.id == user.id + assert order.customer == customer assert order.product == product_one_time assert order.product_price == product_one_time.prices[0] assert order.subscription is None @@ -658,7 +656,7 @@ async def test_one_time_product( enqueue_job_mock.assert_any_call( "benefit.enqueue_benefits_grants", task="grant", - user_id=user.id, + customer_id=customer.id, product_id=product_one_time.id, order_id=order.id, ) @@ -670,7 +668,7 @@ async def test_one_time_product_discount( save_fixture: SaveFixture, product_one_time: Product, organization_account: Account, - user: User, + customer: Customer, discount_fixed_once: Discount, event_creation_time: tuple[datetime, int], ) -> None: @@ -685,9 +683,6 @@ async def test_one_time_product_discount( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -725,7 +720,7 @@ async def test_one_time_custom_price_product( save_fixture: SaveFixture, product_one_time_custom_price: Product, organization_account: Account, - user: User, + customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: created_datetime, created_unix_timestamp = event_creation_time @@ -748,9 +743,6 @@ async def test_one_time_custom_price_product( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -777,7 +769,7 @@ async def test_one_time_custom_price_product( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice_total - assert order.user.id == user.id + assert order.customer == customer assert order.product == product_one_time_custom_price assert order.product_price == product_one_time_custom_price.prices[0] assert order.subscription is None @@ -792,7 +784,7 @@ async def test_one_time_custom_price_product( enqueue_job_mock.assert_any_call( "benefit.enqueue_benefits_grants", task="grant", - user_id=user.id, + customer_id=customer.id, product_id=product_one_time_custom_price.id, order_id=order.id, ) @@ -801,9 +793,8 @@ async def test_one_time_free_product( self, enqueue_job_mock: AsyncMock, session: AsyncSession, - save_fixture: SaveFixture, product_one_time_free_price: Product, - user: User, + customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: created_datetime, created_unix_timestamp = event_creation_time @@ -821,13 +812,10 @@ async def test_one_time_free_product( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice_total - assert order.user.id == user.id + assert order.customer == customer assert order.product == product_one_time_free_price assert order.product_price == product_one_time_free_price.prices[0] assert order.subscription is None @@ -841,7 +829,7 @@ async def test_one_time_free_product( enqueue_job_mock.assert_any_call( "benefit.enqueue_benefits_grants", task="grant", - user_id=user.id, + customer_id=customer.id, product_id=product_one_time_free_price.id, order_id=order.id, ) @@ -854,7 +842,7 @@ async def test_charge_from_metadata( save_fixture: SaveFixture, product_one_time: Product, organization_account: Account, - user: User, + customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: mock = MagicMock(spec=StripeService) @@ -875,9 +863,6 @@ async def test_charge_from_metadata( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -904,7 +889,7 @@ async def test_charge_from_metadata( order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.amount == invoice_total - assert order.user.id == user.id + assert order.customer == customer assert order.product == product_one_time assert order.product_price == product_one_time.prices[0] assert order.subscription is None @@ -919,7 +904,7 @@ async def test_charge_from_metadata( enqueue_job_mock.assert_any_call( "benefit.enqueue_benefits_grants", task="grant", - user_id=user.id, + customer_id=customer.id, product_id=product_one_time.id, order_id=order.id, ) @@ -938,7 +923,7 @@ async def test_no_billing_address( mocker: MockerFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, organization_account: Account, event_creation_time: tuple[datetime, int], ) -> None: @@ -960,9 +945,6 @@ async def test_no_billing_address( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -996,7 +978,7 @@ async def test_billing_address_from_payment_method( save_fixture: SaveFixture, session: AsyncSession, product_one_time: Product, - user: User, + customer: Customer, organization_account: Account, event_creation_time: tuple[datetime, int], ) -> None: @@ -1026,9 +1008,6 @@ async def test_billing_address_from_payment_method( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -1052,9 +1031,6 @@ async def test_billing_address_from_payment_method( spec=PlatformFeeTransactionService, ) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - order = await order_service.create_order_from_stripe(session, invoice=invoice) assert order.billing_address == Address(country="US") # type: ignore assert order.created_at == created_datetime @@ -1066,7 +1042,7 @@ async def test_with_checkout( save_fixture: SaveFixture, product_one_time: Product, organization_account: Account, - user: User, + customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: publish_checkout_event_mock = mocker.patch( @@ -1089,9 +1065,6 @@ async def test_with_checkout( ) invoice_total = invoice.total - (invoice.tax or 0) - user.stripe_customer_id = "CUSTOMER_ID" - await save_fixture(user) - payment_transaction = await create_transaction( save_fixture, type=TransactionType.payment ) @@ -1134,13 +1107,13 @@ async def test_send_confirmation_email( save_fixture: SaveFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, organization: Organization, ) -> None: with WatcherEmailSender() as email_sender: mocker.patch("polar.order.service.get_email_sender", return_value=email_sender) - order = await create_order(save_fixture, product=product, user=user) + order = await create_order(save_fixture, product=product, customer=customer) async def _send_confirmation_email() -> None: await order_service.send_confirmation_email(session, organization, order) diff --git a/server/tests/transaction/conftest.py b/server/tests/transaction/conftest.py index 706bbb655b..f71873b04f 100644 --- a/server/tests/transaction/conftest.py +++ b/server/tests/transaction/conftest.py @@ -7,6 +7,7 @@ from polar.enums import AccountType from polar.models import ( Account, + Customer, ExternalOrganization, Issue, IssueReward, @@ -32,7 +33,7 @@ async def create_transaction( save_fixture: SaveFixture, *, account: Account | None = None, - payment_user: User | None = None, + payment_customer: Customer | None = None, payment_organization: Organization | None = None, type: TransactionType = TransactionType.balance, amount: int = 1000, @@ -56,7 +57,7 @@ async def create_transaction( account_amount=int(amount * 0.9) if account_currency != "usd" else amount, tax_amount=0, account=account, - payment_user=payment_user, + payment_customer=payment_customer, payment_organization=payment_organization, pledge=pledge, issue_reward=issue_reward, @@ -140,12 +141,14 @@ async def transaction_issue_reward( @pytest_asyncio.fixture async def transaction_order_subscription( - save_fixture: SaveFixture, organization: Organization, user: User + save_fixture: SaveFixture, organization: Organization, customer: Customer ) -> Order: product = await create_product(save_fixture, organization=organization) - subscription = await create_subscription(save_fixture, product=product, user=user) + subscription = await create_subscription( + save_fixture, product=product, customer=customer + ) return await create_order( - save_fixture, product=product, user=user, subscription=subscription + save_fixture, product=product, customer=customer, subscription=subscription ) @@ -190,10 +193,12 @@ async def account_transactions( @pytest_asyncio.fixture -async def user_transactions(save_fixture: SaveFixture, user: User) -> list[Transaction]: +async def user_transactions( + save_fixture: SaveFixture, customer: Customer +) -> list[Transaction]: return [ await create_transaction( - save_fixture, type=TransactionType.payment, payment_user=user + save_fixture, type=TransactionType.payment, customer=customer ), ] From 1ec505110826d7c5190f830baff7df32a1c94441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 15:39:56 +0100 Subject: [PATCH 06/47] server/subscription: fix logic and tests related to customers changes --- server/tests/subscription/test_endpoints.py | 10 +- server/tests/subscription/test_service.py | 137 +++++++++----------- 2 files changed, 68 insertions(+), 79 deletions(-) diff --git a/server/tests/subscription/test_endpoints.py b/server/tests/subscription/test_endpoints.py index 1b55f833c9..47a0533178 100644 --- a/server/tests/subscription/test_endpoints.py +++ b/server/tests/subscription/test_endpoints.py @@ -3,7 +3,7 @@ import pytest from httpx import AsyncClient -from polar.models import Organization, Product, User, UserOrganization +from polar.models import Customer, Organization, Product, UserOrganization from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_active_subscription @@ -23,14 +23,14 @@ async def test_valid( self, save_fixture: SaveFixture, client: AsyncClient, - user: User, user_organization: UserOrganization, product: Product, + customer: Customer, ) -> None: await create_active_subscription( save_fixture, product=product, - user=user, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) @@ -43,5 +43,5 @@ async def test_valid( assert json["pagination"]["total_count"] == 1 for item in json["items"]: assert "user" in item - assert "github_username" in item["user"] - assert "email" in item["user"] + assert "customer" in item + assert item["user"]["id"] == item["customer"]["id"] diff --git a/server/tests/subscription/test_service.py b/server/tests/subscription/test_service.py index 2408ab7d7c..3417691ac5 100644 --- a/server/tests/subscription/test_service.py +++ b/server/tests/subscription/test_service.py @@ -8,9 +8,11 @@ from polar.auth.models import AuthSubject from polar.authz.service import Authz from polar.checkout.eventstream import CheckoutEvent +from polar.customer.service import customer as customer_service from polar.kit.pagination import PaginationParams from polar.models import ( Benefit, + Customer, Discount, Organization, Product, @@ -29,7 +31,6 @@ SubscriptionDoesNotExist, ) from polar.subscription.service import subscription as subscription_service -from polar.user.service.user import user as user_service from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.email import WatcherEmailSender, watch_email @@ -43,9 +44,8 @@ def construct_stripe_subscription( *, - user: User | None = None, + customer: Customer | None = None, organization: Organization | None = None, - customer_id: str = "CUSTOMER_ID", price_id: str = "PRICE_ID", status: SubscriptionStatus = SubscriptionStatus.incomplete, latest_invoice: stripe_lib.Invoice | None = None, @@ -55,7 +55,6 @@ def construct_stripe_subscription( ) -> stripe_lib.Subscription: now_timestamp = datetime.now(UTC).timestamp() base_metadata: dict[str, str] = { - **({"user_id": str(user.id)} if user is not None else {}), **( {"organization_subscriber_id": str(organization.id)} if organization is not None @@ -65,7 +64,9 @@ def construct_stripe_subscription( return stripe_lib.Subscription.construct_from( { "id": "SUBSCRIPTION_ID", - "customer": customer_id, + "customer": customer.stripe_customer_id + if customer is not None + else "CUSTOMER_ID", "status": status, "items": { "data": [ @@ -92,12 +93,19 @@ def construct_stripe_subscription( def construct_stripe_customer( - *, id: str = "CUSTOMER_ID", email: str = "backer@example.com" + *, + id: str = "CUSTOMER_ID", + email: str = "customer@example.com", + name: str | None = "Customer Name", ) -> stripe_lib.Customer: return stripe_lib.Customer.construct_from( { "id": id, "email": email, + "name": name, + "address": { + "country": "FR", + }, }, None, ) @@ -138,7 +146,7 @@ async def test_valid_fixed_price( mocker: MockerFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -150,11 +158,11 @@ async def test_valid_fixed_price( price = product.prices[0] assert isinstance(price, ProductPriceFixed) subscription = await subscription_service.create_arbitrary_subscription( - session, user=user, product=product, price=price + session, customer=customer, product=product, price=price ) assert subscription.product_id == product.id - assert subscription.user_id == user.id + assert subscription.customer == customer assert subscription.amount == price.price_amount assert subscription.currency == price.price_currency assert subscription.recurring_interval == price.recurring_interval @@ -166,7 +174,7 @@ async def test_valid_custom_price( mocker: MockerFixture, session: AsyncSession, product_recurring_custom_price: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -179,14 +187,14 @@ async def test_valid_custom_price( assert isinstance(price, ProductPriceCustom) subscription = await subscription_service.create_arbitrary_subscription( session, - user=user, + customer=customer, product=product_recurring_custom_price, price=price, amount=2000, ) assert subscription.product_id == product_recurring_custom_price.id - assert subscription.user_id == user.id + assert subscription.customer == customer assert subscription.amount == 2000 assert subscription.currency == price.price_currency assert subscription.recurring_interval == price.recurring_interval @@ -198,7 +206,7 @@ async def test_valid_free_price( mocker: MockerFixture, session: AsyncSession, product_recurring_free_price: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -211,13 +219,13 @@ async def test_valid_free_price( assert isinstance(price, ProductPriceFree) subscription = await subscription_service.create_arbitrary_subscription( session, - user=user, + customer=customer, product=product_recurring_free_price, price=price, ) assert subscription.product_id == product_recurring_free_price.id - assert subscription.user_id == user.id + assert subscription.customer == customer assert subscription.amount is None assert subscription.currency is None assert subscription.recurring_interval == price.recurring_interval @@ -238,7 +246,7 @@ async def test_not_existing_subscription_tier(self, session: AsyncSession) -> No session, stripe_subscription=stripe_subscription ) - async def test_new_user( + async def test_new_customer( self, session: AsyncSession, stripe_service_mock: MagicMock, @@ -262,24 +270,26 @@ async def test_new_user( assert subscription.stripe_subscription_id == stripe_subscription.id assert subscription.product_id == product.id - user = await user_service.get(session, subscription.user_id) - assert user is not None - assert user.email == stripe_customer.email - assert user.stripe_customer_id == stripe_subscription.customer + customer = await customer_service.get_by_stripe_customer_id( + session, stripe_customer.id + ) + assert customer is not None + assert customer.email == stripe_customer.email + assert customer.stripe_customer_id == stripe_subscription.customer - async def test_existing_user( + async def test_existing_customer( self, session: AsyncSession, stripe_service_mock: MagicMock, product: Product, - user: User, + customer: Customer, ) -> None: stripe_customer = construct_stripe_customer() get_customer_mock = stripe_service_mock.get_customer get_customer_mock.return_value = stripe_customer stripe_subscription = construct_stripe_subscription( - user=user, price_id=product.prices[0].stripe_price_id + customer=customer, price_id=product.prices[0].stripe_price_id ) # then @@ -292,27 +302,21 @@ async def test_existing_user( assert subscription.stripe_subscription_id == stripe_subscription.id assert subscription.product_id == product.id - assert subscription.user_id == user.id - - # load user - user_loaded = await user_service.get(session, user.id) - assert user_loaded - - assert user_loaded.stripe_customer_id == stripe_subscription.customer + assert subscription.customer == customer async def test_set_started_at( self, session: AsyncSession, stripe_service_mock: MagicMock, product: Product, - user: User, + customer: Customer, ) -> None: stripe_customer = construct_stripe_customer() get_customer_mock = stripe_service_mock.get_customer get_customer_mock.return_value = stripe_customer stripe_subscription = construct_stripe_subscription( - user=user, + customer=customer, price_id=product.prices[0].stripe_price_id, status=SubscriptionStatus.active, ) @@ -327,12 +331,6 @@ async def test_set_started_at( assert subscription.status == SubscriptionStatus.active assert subscription.started_at is not None - # load user - user_loaded = await user_service.get(session, user.id) - assert user_loaded - - assert user_loaded.stripe_customer_id == stripe_subscription.customer - async def test_free_price( self, session: AsyncSession, @@ -367,21 +365,19 @@ async def test_subscription_update( stripe_service_mock: MagicMock, product_recurring_free_price: Product, product: Product, - user: User, + customer: Customer, ) -> None: stripe_customer = construct_stripe_customer() get_customer_mock = stripe_service_mock.get_customer get_customer_mock.return_value = stripe_customer existing_subscription = await create_active_subscription( - save_fixture, - product=product_recurring_free_price, - user=user, + save_fixture, product=product_recurring_free_price, customer=customer ) price = product.prices[0] stripe_subscription = construct_stripe_subscription( - user=user, + customer=customer, price_id=price.stripe_price_id, status=SubscriptionStatus.active, metadata={"subscription_id": str(existing_subscription.id)}, @@ -400,12 +396,6 @@ async def test_subscription_update( assert subscription.price == price assert subscription.product == product - # load user - user_loaded = await user_service.get(session, user.id) - assert user_loaded - - assert user_loaded.stripe_customer_id == stripe_subscription.customer - async def test_discount( self, session: AsyncSession, @@ -485,7 +475,7 @@ async def test_not_existing_price( session: AsyncSession, save_fixture: SaveFixture, product: Product, - user: User, + customer: Customer, ) -> None: stripe_subscription = construct_stripe_subscription( status=SubscriptionStatus.active, price_id="NOT_EXISTING_PRICE_ID" @@ -493,7 +483,7 @@ async def test_not_existing_price( subscription = await create_subscription( save_fixture, product=product, - user=user, + customer=customer, stripe_subscription_id=stripe_subscription.id, ) assert subscription.started_at is None @@ -512,7 +502,7 @@ async def test_valid( session: AsyncSession, save_fixture: SaveFixture, product: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -526,7 +516,7 @@ async def test_valid( save_fixture, product=product, price=price, - user=user, + customer=customer, stripe_subscription_id=stripe_subscription.id, ) assert subscription.started_at is None @@ -551,7 +541,7 @@ async def test_valid_cancel_at_period_end( session: AsyncSession, save_fixture: SaveFixture, product: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -567,7 +557,7 @@ async def test_valid_cancel_at_period_end( save_fixture, product=product, price=price, - user=user, + customer=customer, stripe_subscription_id=stripe_subscription.id, ) assert subscription.started_at is None @@ -593,7 +583,7 @@ async def test_valid_new_price( save_fixture: SaveFixture, product_recurring_free_price: Product, product: Product, - user: User, + customer: Customer, ) -> None: enqueue_benefits_grants_mock = mocker.patch.object( subscription_service, "enqueue_benefits_grants" @@ -609,7 +599,7 @@ async def test_valid_new_price( save_fixture, product=product_recurring_free_price, price=free_price, - user=user, + customer=customer, stripe_subscription_id=stripe_subscription.id, ) assert subscription.started_at is None @@ -635,7 +625,7 @@ async def test_valid_discount( session: AsyncSession, save_fixture: SaveFixture, product: Product, - user: User, + customer: Customer, discount_fixed_once: Discount, ) -> None: price = product.prices[0] @@ -648,7 +638,7 @@ async def test_valid_discount( save_fixture, product=product, price=price, - user=user, + customer=customer, stripe_subscription_id=stripe_subscription.id, ) assert subscription.discount is None @@ -728,7 +718,7 @@ async def test_active_subscription( call( "benefit.enqueue_benefits_grants", task="grant", - user_id=subscription.user_id, + customer_id=subscription.customer_id, product_id=product.id, subscription_id=subscription.id, ) @@ -772,7 +762,7 @@ async def test_canceled_subscription( call( "benefit.enqueue_benefits_grants", task="revoke", - user_id=subscription.user_id, + customer_id=subscription.customer_id, product_id=product.id, subscription_id=subscription.id, ) @@ -787,7 +777,7 @@ async def test_valid( session: AsyncSession, save_fixture: SaveFixture, mocker: MockerFixture, - user: User, + customer: Customer, product: Product, product_second: Product, ) -> None: @@ -796,15 +786,15 @@ async def test_valid( ) subscription_1 = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) subscription_2 = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) await create_subscription( save_fixture, product=product_second, - user=user, + customer=customer, ) # then @@ -825,13 +815,13 @@ async def test_user_not_organization_member( auth_subject: AuthSubject[User], session: AsyncSession, save_fixture: SaveFixture, - user_second: User, product: Product, + customer: Customer, ) -> None: await create_active_subscription( save_fixture, product=product, - user=user_second, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) @@ -852,14 +842,14 @@ async def test_user_organization_member( auth_subject: AuthSubject[User], session: AsyncSession, save_fixture: SaveFixture, - user_second: User, user_organization: UserOrganization, product: Product, + customer: Customer, ) -> None: await create_active_subscription( save_fixture, product=product, - user=user_second, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) @@ -881,13 +871,13 @@ async def test_organization( session: AsyncSession, save_fixture: SaveFixture, organization: Organization, - user_second: User, product: Product, + customer: Customer, ) -> None: await create_active_subscription( save_fixture, product=product, - user=user_second, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) @@ -911,8 +901,7 @@ async def test_send_confirmation_email( save_fixture: SaveFixture, session: AsyncSession, product: Product, - user: User, - organization: Organization, + customer: Customer, ) -> None: with WatcherEmailSender() as email_sender: mocker.patch( @@ -920,7 +909,7 @@ async def test_send_confirmation_email( ) subscription = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) async def _send_confirmation_email() -> None: From e79a0229a16f3889fca844ab8a3016634a4fcc36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 3 Dec 2024 16:18:58 +0100 Subject: [PATCH 07/47] server/benefit: fix logic and tests related to customers changes --- server/polar/benefit/schemas.py | 11 +- server/polar/benefit/tasks.py | 54 ++--- server/tests/benefit/benefits/test_ads.py | 13 +- .../benefit/benefits/test_downloadables.py | 135 +++++++------ .../benefit/service/test_benefit_grant.py | 185 +++++++++--------- server/tests/benefit/test_tasks.py | 62 +++--- server/tests/fixtures/downloadable.py | 28 +-- server/tests/fixtures/random_objects.py | 16 +- 8 files changed, 266 insertions(+), 238 deletions(-) diff --git a/server/polar/benefit/schemas.py b/server/polar/benefit/schemas.py index ea4f5a41cb..b03b13d82f 100644 --- a/server/polar/benefit/schemas.py +++ b/server/polar/benefit/schemas.py @@ -52,7 +52,7 @@ class BenefitProperties(Schema): ... str | None, Field( description=( - "Private note to be shared with users who have this benefit granted." + "Private note to be shared with customers who have this benefit granted." ), ), ] @@ -500,7 +500,7 @@ class BenefitLicenseKeys(BenefitBase): class BenefitGrantBase(IDSchema, TimestampedSchema): """ - A grant of a benefit to a user. + A grant of a benefit to a customer. """ id: UUID4 = Field(description="The ID of the grant.") @@ -526,7 +526,12 @@ class BenefitGrantBase(IDSchema, TimestampedSchema): order_id: UUID4 | None = Field( description="The ID of the order that granted this benefit." ) - user_id: UUID4 = Field(description="The ID of the user concerned by this grant.") + customer_id: UUID4 = Field( + description="The ID of the customer concerned by this grant." + ) + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) benefit_id: UUID4 = Field( description="The ID of the benefit concerned by this grant." ) diff --git a/server/polar/benefit/tasks.py b/server/polar/benefit/tasks.py index 0b09838d51..2110153c3e 100644 --- a/server/polar/benefit/tasks.py +++ b/server/polar/benefit/tasks.py @@ -4,12 +4,12 @@ import structlog from arq import Retry +from polar.customer.service import customer as customer_service from polar.exceptions import PolarTaskError from polar.logging import Logger from polar.models.benefit import BenefitType from polar.models.benefit_grant import BenefitGrantScopeArgs from polar.product.service.product import product as product_service -from polar.user.service.user import user as user_service from polar.worker import ( AsyncSessionMaker, JobContext, @@ -29,10 +29,10 @@ class BenefitTaskError(PolarTaskError): ... -class UserDoesNotExist(BenefitTaskError): - def __init__(self, user_id: uuid.UUID) -> None: - self.user_id = user_id - message = f"The user with id {user_id} does not exist." +class CustomerDoesNotExist(BenefitTaskError): + def __init__(self, customer_id: uuid.UUID) -> None: + self.customer_id = customer_id + message = f"The customer with id {customer_id} does not exist." super().__init__(message) @@ -68,15 +68,15 @@ def __init__(self, organization_id: uuid.UUID) -> None: async def enqueue_benefits_grants( ctx: JobContext, task: Literal["grant", "revoke"], - user_id: uuid.UUID, + customer_id: uuid.UUID, product_id: uuid.UUID, polar_context: PolarWorkerContext, **scope: Unpack[BenefitGrantScopeArgs], ) -> None: async with AsyncSessionMaker(ctx) as session: - user = await user_service.get(session, user_id) - if user is None: - raise UserDoesNotExist(user_id) + customer = await customer_service.get(session, customer_id) + if customer is None: + raise CustomerDoesNotExist(customer_id) product = await product_service.get(session, product_id) if product is None: @@ -85,22 +85,22 @@ async def enqueue_benefits_grants( resolved_scope = await resolve_scope(session, scope) await benefit_grant_service.enqueue_benefits_grants( - session, task, user, product, **resolved_scope + session, task, customer, product, **resolved_scope ) @task("benefit.grant") async def benefit_grant( ctx: JobContext, - user_id: uuid.UUID, + customer_id: uuid.UUID, benefit_id: uuid.UUID, polar_context: PolarWorkerContext, **scope: Unpack[BenefitGrantScopeArgs], ) -> None: async with AsyncSessionMaker(ctx) as session: - user = await user_service.get(session, user_id) - if user is None: - raise UserDoesNotExist(user_id) + customer = await customer_service.get(session, customer_id) + if customer is None: + raise CustomerDoesNotExist(customer_id) benefit = await benefit_service.get(session, benefit_id, loaded=True) if benefit is None: @@ -112,7 +112,7 @@ async def benefit_grant( await benefit_grant_service.grant_benefit( session, get_worker_redis(ctx), - user, + customer, benefit, attempt=ctx["job_try"], **resolved_scope, @@ -123,7 +123,7 @@ async def benefit_grant( error=str(e), defer_seconds=e.defer_seconds, benefit_id=str(benefit_id), - user_id=str(user_id), + customer_id=str(customer_id), ) raise Retry(e.defer_seconds) from e @@ -131,15 +131,15 @@ async def benefit_grant( @task("benefit.revoke") async def benefit_revoke( ctx: JobContext, - user_id: uuid.UUID, + customer_id: uuid.UUID, benefit_id: uuid.UUID, polar_context: PolarWorkerContext, **scope: Unpack[BenefitGrantScopeArgs], ) -> None: async with AsyncSessionMaker(ctx) as session: - user = await user_service.get(session, user_id) - if user is None: - raise UserDoesNotExist(user_id) + customer = await customer_service.get(session, customer_id) + if customer is None: + raise CustomerDoesNotExist(customer_id) benefit = await benefit_service.get(session, benefit_id, loaded=True) if benefit is None: @@ -151,7 +151,7 @@ async def benefit_revoke( await benefit_grant_service.revoke_benefit( session, get_worker_redis(ctx), - user, + customer, benefit, attempt=ctx["job_try"], **resolved_scope, @@ -162,7 +162,7 @@ async def benefit_revoke( error=str(e), defer_seconds=e.defer_seconds, benefit_id=str(benefit_id), - user_id=str(user_id), + customer_id=str(customer_id), ) raise Retry(e.defer_seconds) from e @@ -222,15 +222,15 @@ async def benefit_delete( @task("benefit.precondition_fulfilled") async def benefit_precondition_fulfilled( ctx: JobContext, - user_id: uuid.UUID, + customer_id: uuid.UUID, benefit_type: BenefitType, polar_context: PolarWorkerContext, ) -> None: async with AsyncSessionMaker(ctx) as session: - user = await user_service.get(session, user_id) - if user is None: - raise UserDoesNotExist(user_id) + customer = await customer_service.get(session, customer_id) + if customer is None: + raise CustomerDoesNotExist(customer_id) await benefit_grant_service.enqueue_grants_after_precondition_fulfilled( - session, user, benefit_type + session, customer, benefit_type ) diff --git a/server/tests/benefit/benefits/test_ads.py b/server/tests/benefit/benefits/test_ads.py index 8bf494360d..34b3cff90c 100644 --- a/server/tests/benefit/benefits/test_ads.py +++ b/server/tests/benefit/benefits/test_ads.py @@ -3,11 +3,8 @@ import pytest from polar.benefit.benefits.ads import BenefitAdsService -from polar.models import BenefitGrant, Organization, User -from polar.models.benefit import ( - BenefitAds, - BenefitType, -) +from polar.models import BenefitGrant, Customer, Organization +from polar.models.benefit import BenefitAds, BenefitType from polar.models.benefit_grant import BenefitGrantAdsProperties from polar.postgres import AsyncSession from polar.redis import Redis @@ -21,7 +18,7 @@ async def test_grant( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, ) -> None: benefit = cast( @@ -34,11 +31,11 @@ async def test_grant( ), ) properties: BenefitGrantAdsProperties = {"advertisement_campaign_id": "CAMPAIGN_ID"} - grant = BenefitGrant(user=user, benefit=benefit, properties=properties) + grant = BenefitGrant(customer=customer, benefit=benefit, properties=properties) benefit_ads_service = BenefitAdsService(session, redis) updated_properties = await benefit_ads_service.grant( - benefit, user, cast(BenefitGrantAdsProperties, grant.properties) + benefit, customer, cast(BenefitGrantAdsProperties, grant.properties) ) assert updated_properties == properties diff --git a/server/tests/benefit/benefits/test_downloadables.py b/server/tests/benefit/benefits/test_downloadables.py index 52823fae3d..fbf9214ea5 100644 --- a/server/tests/benefit/benefits/test_downloadables.py +++ b/server/tests/benefit/benefits/test_downloadables.py @@ -4,12 +4,7 @@ from polar.benefit.schemas import BenefitDownloadablesCreateProperties from polar.file.schemas import FileRead -from polar.models import ( - Downloadable, - Organization, - Product, - User, -) +from polar.models import Customer, Downloadable, Organization, Product from polar.models.downloadable import DownloadableStatus from polar.postgres import AsyncSession from polar.redis import Redis @@ -26,7 +21,7 @@ async def test_grant_one( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -35,16 +30,18 @@ async def test_grant_one( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( files=[uploaded_logo_jpg.id] ), ) - assert granted["files"][0] == str(uploaded_logo_jpg.id) + assert granted.get("files", [])[0] == str(uploaded_logo_jpg.id) - downloadables = await TestDownloadable.get_user_downloadables(session, user) + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == 1 @@ -58,7 +55,7 @@ async def test_grant_multiple( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -73,7 +70,7 @@ async def test_grant_multiple( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -81,12 +78,14 @@ async def test_grant_multiple( ), ) - downloadables = await TestDownloadable.get_user_downloadables(session, user) + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == len(files) for i, file in enumerate(files): - assert granted["files"][i] == str(file.id) + assert granted.get("files", [])[i] == str(file.id) downloadable = downloadables[i] assert downloadable.status == DownloadableStatus.granted assert downloadable.file_id == file.id @@ -98,7 +97,7 @@ async def test_grant_unless_archived( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -108,7 +107,7 @@ async def test_grant_unless_archived( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -121,10 +120,12 @@ async def test_grant_unless_archived( ], ), ) - assert len(granted["files"]) == 1 - assert granted["files"][0] == str(uploaded_logo_png.id) + assert len(granted.get("files", [])) == 1 + assert granted.get("files", [])[0] == str(uploaded_logo_png.id) - downloadables = await TestDownloadable.get_user_downloadables(session, user) + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == 1 @@ -138,7 +139,7 @@ async def test_revoke_one( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -147,7 +148,7 @@ async def test_revoke_one( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -158,10 +159,12 @@ async def test_revoke_one( ) # First granted - assert len(granted["files"]) == 1 - assert granted["files"][0] == str(uploaded_logo_jpg.id) + assert len(granted.get("files", [])) == 1 + assert granted.get("files", [])[0] == str(uploaded_logo_jpg.id) - downloadables = await TestDownloadable.get_user_downloadables(session, user) + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == 1 @@ -169,11 +172,11 @@ async def test_revoke_one( assert downloadable.status == DownloadableStatus.granted assert downloadable.file_id == uploaded_logo_jpg.id - await TestDownloadable.run_revoke_task(session, redis, benefit, user) + await TestDownloadable.run_revoke_task(session, redis, benefit, customer) # Now revoked - updated_downloadables = await TestDownloadable.get_user_downloadables( - session, user + updated_downloadables = await TestDownloadable.get_customer_downloadables( + session, customer ) assert updated_downloadables assert len(updated_downloadables) == 1 @@ -188,7 +191,7 @@ async def test_revoke_multiple( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -202,7 +205,7 @@ async def test_revoke_multiple( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -211,9 +214,9 @@ async def test_revoke_multiple( ) # First granted - assert len(granted["files"]) == 2 - granted_downloadables = await TestDownloadable.get_user_downloadables( - session, user + assert len(granted.get("files", [])) == 2 + granted_downloadables = await TestDownloadable.get_customer_downloadables( + session, customer ) assert len(granted_downloadables) == 2 for i, file in enumerate(files): @@ -221,11 +224,11 @@ async def test_revoke_multiple( assert grant.file_id == file.id assert grant.status == DownloadableStatus.granted - await TestDownloadable.run_revoke_task(session, redis, benefit, user) + await TestDownloadable.run_revoke_task(session, redis, benefit, customer) # Now revoked - revoked_downloadables = await TestDownloadable.get_user_downloadables( - session, user + revoked_downloadables = await TestDownloadable.get_customer_downloadables( + session, customer ) assert len(revoked_downloadables) == 2 for i, file in enumerate(files): @@ -239,7 +242,7 @@ async def test_archive_grant_retroactively( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -259,16 +262,18 @@ async def test_archive_grant_retroactively( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=properties, ) - assert len(granted["files"]) == 1 - assert granted["files"][0] == str(files[1].id) + assert len(granted.get("files", [])) == 1 + assert granted.get("files", [])[0] == str(files[1].id) - downloadables = await TestDownloadable.get_user_downloadables(session, user) + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == 1 @@ -282,11 +287,13 @@ async def test_archive_grant_retroactively( await session.flush() _, updated_granted = await TestDownloadable.run_grant_task( - session, redis, benefit, user + session, redis, benefit, customer ) - assert len(updated_granted["files"]) == 2 - downloadables = await TestDownloadable.get_user_downloadables(session, user) + assert len(updated_granted.get("files", [])) == 2 + downloadables = await TestDownloadable.get_customer_downloadables( + session, customer + ) assert downloadables assert len(downloadables) == 2 @@ -297,7 +304,7 @@ def find_downloadable(file_id: UUID) -> Downloadable | None: return None for i, file in enumerate(files): - assert updated_granted["files"][i] == str(file.id) + assert updated_granted.get("files", [])[i] == str(file.id) updated_downloadable = find_downloadable(file.id) assert updated_downloadable assert updated_downloadable.status == DownloadableStatus.granted @@ -305,13 +312,13 @@ def find_downloadable(file_id: UUID) -> Downloadable | None: assert updated_downloadable.deleted_at is None @pytest.mark.auth - async def test_archive_for_new_users( + async def test_archive_for_new_customers( self, session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, - user_second: User, + customer: Customer, + customer_second: Customer, organization: Organization, product: Product, uploaded_logo_jpg: FileRead, @@ -321,11 +328,11 @@ async def test_archive_for_new_users( uploaded_logo_jpg, uploaded_logo_png, ] - benefit, user_granted = await TestDownloadable.create_benefit_and_grant( + benefit, customer_granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -333,14 +340,14 @@ async def test_archive_for_new_users( ), ) - # First user granted all files - assert len(user_granted["files"]) == 2 - user_downloadables = await TestDownloadable.get_user_downloadables( - session, user + # First customer granted all files + assert len(customer_granted.get("files", [])) == 2 + customer_downloadables = await TestDownloadable.get_customer_downloadables( + session, customer ) - assert len(user_downloadables) == 2 + assert len(customer_downloadables) == 2 for i, file in enumerate(files): - grant = user_downloadables[i] + grant = customer_downloadables[i] assert grant.file_id == file.id assert grant.status == DownloadableStatus.granted @@ -352,20 +359,20 @@ async def test_archive_for_new_users( await session.flush() session.expunge(benefit) - # Second user granted one file + # Second customer granted one file # Since they subscribe after the 2nd file was archived - _, user_second_granted = await TestDownloadable.create_grant( + _, customer_second_granted = await TestDownloadable.create_grant( session, redis, save_fixture, benefit, - user=user_second, + customer=customer_second, product=product, ) - assert len(user_second_granted["files"]) == 1 - user_second_downloadables = await TestDownloadable.get_user_downloadables( - session, user_second + assert len(customer_second_granted.get("files", [])) == 1 + customer_second_downloadables = ( + await TestDownloadable.get_customer_downloadables(session, customer_second) ) - assert len(user_second_downloadables) == 1 - assert user_second_downloadables[0].file_id == files[1].id - assert user_second_downloadables[0].status == DownloadableStatus.granted + assert len(customer_second_downloadables) == 1 + assert customer_second_downloadables[0].file_id == files[1].id + assert customer_second_downloadables[0].status == DownloadableStatus.granted diff --git a/server/tests/benefit/service/test_benefit_grant.py b/server/tests/benefit/service/test_benefit_grant.py index f8a4e4ff4d..6834a802df 100644 --- a/server/tests/benefit/service/test_benefit_grant.py +++ b/server/tests/benefit/service/test_benefit_grant.py @@ -11,7 +11,7 @@ from polar.benefit.service.benefit_grant import ( # type: ignore[attr-defined] notification_service, ) -from polar.models import Benefit, BenefitGrant, Product, Subscription, User +from polar.models import Benefit, BenefitGrant, Customer, Product, Subscription from polar.notifications.notification import ( BenefitPreconditionErrorNotificationContextualPayload, ) @@ -46,7 +46,7 @@ async def test_not_existing_grant( session: AsyncSession, redis: Redis, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: @@ -56,11 +56,11 @@ async def test_not_existing_grant( session.expunge_all() grant = await benefit_grant_service.grant_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert grant.subscription_id == subscription.id - assert grant.user_id == user.id + assert grant.customer == customer assert grant.benefit_id == benefit_organization.id assert grant.is_granted assert grant.properties == {"external_id": "abc"} @@ -72,14 +72,12 @@ async def test_existing_grant_not_granted( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, customer=customer, benefit=benefit_organization ) await save_fixture(grant) @@ -87,7 +85,7 @@ async def test_existing_grant_not_granted( session.expunge_all() updated_grant = await benefit_grant_service.grant_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == grant.id @@ -100,14 +98,12 @@ async def test_existing_grant_already_granted( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -116,7 +112,7 @@ async def test_existing_grant_already_granted( session.expunge_all() updated_grant = await benefit_grant_service.grant_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == grant.id @@ -128,7 +124,7 @@ async def test_precondition_error( session: AsyncSession, redis: Redis, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: @@ -138,7 +134,7 @@ async def test_precondition_error( session.expunge_all() grant = await benefit_grant_service.grant_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert not grant.is_granted @@ -148,19 +144,19 @@ async def test_default_properties_value( session: AsyncSession, redis: Redis, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: benefit_service_mock.grant.side_effect = ( - lambda user, benefit, properties, **kwargs: properties + lambda customer, benefit, properties, **kwargs: properties ) # then session.expunge_all() grant = await benefit_grant_service.grant_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert grant.properties == {} @@ -173,7 +169,7 @@ async def test_not_existing_grant( session: AsyncSession, redis: Redis, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: @@ -181,7 +177,7 @@ async def test_not_existing_grant( session.expunge_all() grant = await benefit_grant_service.revoke_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert grant.subscription_id == subscription.id @@ -195,7 +191,7 @@ async def test_existing_grant_not_revoked( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: @@ -205,15 +201,15 @@ async def test_existing_grant_not_revoked( session.expunge_all() grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, + customer=customer, + benefit=benefit_organization, properties={"external_id": "abc"}, ) await save_fixture(grant) updated_grant = await benefit_grant_service.revoke_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == grant.id @@ -227,14 +223,14 @@ async def test_existing_grant_already_revoked( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, + customer=customer, + benefit=benefit_organization, ) grant.set_revoked() await save_fixture(grant) @@ -243,7 +239,7 @@ async def test_existing_grant_already_revoked( session.expunge_all() updated_grant = await benefit_grant_service.revoke_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == grant.id @@ -256,22 +252,25 @@ async def test_several_benefit_grants( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, product: Product, ) -> None: first_grant = await create_benefit_grant( - save_fixture, user, benefit_organization, subscription=subscription + save_fixture, customer, benefit_organization, subscription=subscription ) first_grant.set_granted() await save_fixture(first_grant) second_subscription = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) second_grant = await create_benefit_grant( - save_fixture, user, benefit_organization, subscription=second_subscription + save_fixture, + customer, + benefit_organization, + subscription=second_subscription, ) second_grant.set_granted() await save_fixture(second_grant) @@ -280,7 +279,7 @@ async def test_several_benefit_grants( session.expunge_all() updated_grant = await benefit_grant_service.revoke_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == first_grant.id @@ -293,7 +292,7 @@ async def test_several_benefit_grants_should_individual_revoke( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, product: Product, @@ -302,16 +301,19 @@ async def test_several_benefit_grants_should_individual_revoke( benefit_service_mock.revoke.return_value = {"message": "ok"} first_grant = await create_benefit_grant( - save_fixture, user, benefit_organization, subscription=subscription + save_fixture, customer, benefit_organization, subscription=subscription ) first_grant.set_granted() await save_fixture(first_grant) second_subscription = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) second_grant = await create_benefit_grant( - save_fixture, user, benefit_organization, subscription=second_subscription + save_fixture, + customer, + benefit_organization, + subscription=second_subscription, ) second_grant.set_granted() await save_fixture(second_grant) @@ -320,7 +322,7 @@ async def test_several_benefit_grants_should_individual_revoke( session.expunge_all() updated_grant = await benefit_grant_service.revoke_benefit( - session, redis, user, benefit_organization, subscription=subscription + session, redis, customer, benefit_organization, subscription=subscription ) assert updated_grant.id == first_grant.id @@ -341,7 +343,7 @@ async def test_subscription_scope( save_fixture: SaveFixture, product: Product, benefits: list[Benefit], - user: User, + customer: Customer, subscription: Subscription, ) -> None: enqueue_job_mock = mocker.patch( @@ -353,14 +355,14 @@ async def test_subscription_scope( ) await benefit_grant_service.enqueue_benefits_grants( - session, task, user, product, subscription=subscription + session, task, customer, product, subscription=subscription ) enqueue_job_mock.assert_has_calls( [ call( f"benefit.{task}", - user_id=subscription.user_id, + customer_id=customer.id, benefit_id=benefit.id, subscription_id=subscription.id, ) @@ -376,14 +378,14 @@ async def test_outdated_grants( product: Product, benefits: list[Benefit], subscription: Subscription, - user: User, + customer: Customer, ) -> None: enqueue_job_mock = mocker.patch( "polar.benefit.service.benefit_grant.enqueue_job" ) grant = BenefitGrant( - subscription_id=subscription.id, user_id=user.id, benefit_id=benefits[0].id + subscription=subscription, customer=customer, benefit=benefits[0] ) grant.set_granted() await save_fixture(grant) @@ -393,12 +395,12 @@ async def test_outdated_grants( ) await benefit_grant_service.enqueue_benefits_grants( - session, "grant", user, product, subscription=subscription + session, "grant", customer, product, subscription=subscription ) enqueue_job_mock.assert_any_call( "benefit.revoke", - user_id=subscription.user_id, + customer_id=customer.id, benefit_id=benefits[0].id, subscription_id=subscription.id, ) @@ -435,14 +437,14 @@ async def test_required_update_granted( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_organization_second: Benefit, benefit_service_mock: MagicMock, ) -> None: granted_grant = BenefitGrant( subscription=subscription, - user=user, + customer=customer, benefit=benefit_organization, ) granted_grant.set_granted() @@ -450,7 +452,7 @@ async def test_required_update_granted( other_benefit_grant = BenefitGrant( subscription=subscription, - user=user, + customer=customer, benefit=benefit_organization_second, ) other_benefit_grant.set_granted() @@ -480,19 +482,21 @@ async def test_required_update_revoked( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_organization_second: Benefit, benefit_service_mock: MagicMock, ) -> None: revoked_grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) revoked_grant.set_revoked() await save_fixture(revoked_grant) other_benefit_grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization_second + subscription=subscription, + customer=customer, + benefit=benefit_organization_second, ) other_benefit_grant.set_granted() await save_fixture(other_benefit_grant) @@ -520,12 +524,12 @@ async def test_revoked_grant( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_revoked() await save_fixture(grant) @@ -546,7 +550,7 @@ async def test_granted_grant( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: @@ -554,7 +558,7 @@ async def test_granted_grant( grant = BenefitGrant( subscription=subscription, - user=user, + customer=customer, benefit=benefit_organization, properties={"external_id": "abc"}, ) @@ -584,12 +588,12 @@ async def test_precondition_error( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -618,18 +622,20 @@ async def test_valid( session: AsyncSession, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_organization_second: Benefit, ) -> None: granted_grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) granted_grant.set_granted() await save_fixture(granted_grant) other_benefit_grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization_second + subscription=subscription, + customer=customer, + benefit=benefit_organization_second, ) other_benefit_grant.set_granted() await save_fixture(other_benefit_grant) @@ -658,12 +664,12 @@ async def test_revoked_grant( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_revoked() await save_fixture(grant) @@ -684,12 +690,12 @@ async def test_granted_grant( redis: Redis, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -724,7 +730,7 @@ async def test_no_notification( session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user: User, + customer: Customer, notification_send_to_user_mock: MagicMock, ) -> None: error = BenefitPreconditionError("Error") @@ -733,7 +739,7 @@ async def test_no_notification( session.expunge_all() await benefit_grant_service.handle_precondition_error( - session, error, user, benefit_organization, subscription=subscription + session, error, customer, benefit_organization, subscription=subscription ) notification_send_to_user_mock.assert_not_called() @@ -743,7 +749,7 @@ async def test_email( session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user: User, + customer: Customer, notification_send_to_user_mock: MagicMock, ) -> None: error = BenefitPreconditionError( @@ -765,7 +771,7 @@ async def test_email( await benefit_grant_service.handle_precondition_error( session, error, - user, + customer, benefit_organization, subscription=subscription_loaded, ) @@ -781,17 +787,19 @@ async def test_required_update( session: AsyncSession, save_fixture: SaveFixture, subscription: Subscription, - user: User, - user_second: User, + customer: Customer, + customer_second: Customer, benefit_organization: Benefit, ) -> None: pending_grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) await save_fixture(pending_grant) other_user_grant = BenefitGrant( - subscription=subscription, user=user_second, benefit=benefit_organization + subscription=subscription, + customer=customer_second, + benefit=benefit_organization, ) await save_fixture(other_user_grant) @@ -803,12 +811,12 @@ async def test_required_update( session.expunge_all() await benefit_grant_service.enqueue_grants_after_precondition_fulfilled( - session, user, benefit_organization.type + session, customer, benefit_organization.type ) enqueue_job_mock.assert_called_once_with( "benefit.grant", - user_id=user.id, + customer_id=customer.id, benefit_id=pending_grant.benefit_id, **pending_grant.scope, ) @@ -821,14 +829,12 @@ async def test_existing_grant_incorrect_scope( session: AsyncSession, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, product: Product, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, customer=customer, benefit=benefit_organization ) await save_fixture(grant) @@ -836,13 +842,13 @@ async def test_existing_grant_incorrect_scope( session.expunge_all() other_subscription = await create_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) - order = await create_order(save_fixture, product=product, user=user) + order = await create_order(save_fixture, product=product, customer=customer) retrieved_grant = await benefit_grant_service.get_by_benefit_and_scope( session, - user=user, + customer=customer, benefit=benefit_organization, subscription=other_subscription, order=order, @@ -854,13 +860,11 @@ async def test_existing_grant_correct_scope( session: AsyncSession, save_fixture: SaveFixture, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription_id=subscription.id, - user_id=user.id, - benefit_id=benefit_organization.id, + subscription=subscription, customer=customer, benefit=benefit_organization ) await save_fixture(grant) @@ -868,7 +872,10 @@ async def test_existing_grant_correct_scope( session.expunge_all() retrieved_grant = await benefit_grant_service.get_by_benefit_and_scope( - session, user=user, benefit=benefit_organization, subscription=subscription + session, + customer=customer, + benefit=benefit_organization, + subscription=subscription, ) assert retrieved_grant is not None assert retrieved_grant.id == grant.id diff --git a/server/tests/benefit/test_tasks.py b/server/tests/benefit/test_tasks.py index e1d1cebba2..07103aa5f2 100644 --- a/server/tests/benefit/test_tasks.py +++ b/server/tests/benefit/test_tasks.py @@ -11,7 +11,7 @@ from polar.benefit.tasks import ( # type: ignore[attr-defined] BenefitDoesNotExist, BenefitGrantDoesNotExist, - UserDoesNotExist, + CustomerDoesNotExist, benefit_delete, benefit_grant, benefit_grant_service, @@ -19,7 +19,7 @@ benefit_revoke, benefit_update, ) -from polar.models import Benefit, BenefitGrant, Subscription, User +from polar.models import Benefit, BenefitGrant, Customer, Subscription from polar.models.benefit import BenefitType from polar.postgres import AsyncSession from polar.worker import JobContext, PolarWorkerContext @@ -28,7 +28,7 @@ @pytest.mark.asyncio class TestBenefitGrant: - async def test_not_existing_user( + async def test_not_existing_customer( self, job_context: JobContext, polar_worker_context: PolarWorkerContext, @@ -39,7 +39,7 @@ async def test_not_existing_user( # then session.expunge_all() - with pytest.raises(UserDoesNotExist): + with pytest.raises(CustomerDoesNotExist): await benefit_grant( job_context, uuid.uuid4(), @@ -53,7 +53,7 @@ async def test_not_existing_benefit( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, session: AsyncSession, ) -> None: # then @@ -62,7 +62,7 @@ async def test_not_existing_benefit( with pytest.raises(BenefitDoesNotExist): await benefit_grant( job_context, - user.id, + customer.id, uuid.uuid4(), polar_worker_context, subscription_id=subscription.id, @@ -74,7 +74,7 @@ async def test_existing_benefit( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, session: AsyncSession, ) -> None: @@ -89,7 +89,7 @@ async def test_existing_benefit( await benefit_grant( job_context, - user.id, + customer.id, benefit_organization.id, polar_worker_context, subscription_id=subscription.id, @@ -103,7 +103,7 @@ async def test_retry( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, session: AsyncSession, ) -> None: @@ -120,7 +120,7 @@ async def test_retry( with pytest.raises(Retry): await benefit_grant( job_context, - user.id, + customer.id, benefit_organization.id, polar_worker_context, subscription_id=subscription.id, @@ -129,7 +129,7 @@ async def test_retry( @pytest.mark.asyncio class TestBenefitRevoke: - async def test_not_existing_user( + async def test_not_existing_customer( self, job_context: JobContext, polar_worker_context: PolarWorkerContext, @@ -140,7 +140,7 @@ async def test_not_existing_user( # then session.expunge_all() - with pytest.raises(UserDoesNotExist): + with pytest.raises(CustomerDoesNotExist): await benefit_revoke( job_context, uuid.uuid4(), @@ -154,7 +154,7 @@ async def test_not_existing_benefit( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, session: AsyncSession, ) -> None: # then @@ -163,7 +163,7 @@ async def test_not_existing_benefit( with pytest.raises(BenefitDoesNotExist): await benefit_revoke( job_context, - user.id, + customer.id, uuid.uuid4(), polar_worker_context, subscription_id=subscription.id, @@ -175,7 +175,7 @@ async def test_existing_benefit( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, session: AsyncSession, ) -> None: @@ -190,7 +190,7 @@ async def test_existing_benefit( await benefit_revoke( job_context, - user.id, + customer.id, benefit_organization.id, polar_worker_context, subscription_id=subscription.id, @@ -204,7 +204,7 @@ async def test_retry( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, session: AsyncSession, ) -> None: @@ -221,7 +221,7 @@ async def test_retry( with pytest.raises(Retry): await benefit_revoke( job_context, - user.id, + customer.id, benefit_organization.id, polar_worker_context, subscription_id=subscription.id, @@ -251,11 +251,11 @@ async def test_existing_grant( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -281,11 +281,11 @@ async def test_retry( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -327,11 +327,11 @@ async def test_existing_grant( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -357,11 +357,11 @@ async def test_retry( job_context: JobContext, polar_worker_context: PolarWorkerContext, subscription: Subscription, - user: User, + customer: Customer, benefit_organization: Benefit, ) -> None: grant = BenefitGrant( - subscription=subscription, user=user, benefit=benefit_organization + subscription=subscription, customer=customer, benefit=benefit_organization ) grant.set_granted() await save_fixture(grant) @@ -382,7 +382,7 @@ async def test_retry( @pytest.mark.asyncio class TestBenefitPreconditionFulfilled: - async def test_not_existing_user( + async def test_not_existing_customer( self, job_context: JobContext, polar_worker_context: PolarWorkerContext, @@ -391,7 +391,7 @@ async def test_not_existing_user( # then session.expunge_all() - with pytest.raises(UserDoesNotExist): + with pytest.raises(CustomerDoesNotExist): await benefit_precondition_fulfilled( job_context, uuid.uuid4(), @@ -399,13 +399,13 @@ async def test_not_existing_user( polar_worker_context, ) - async def test_existing_user( + async def test_existing_customer( self, mocker: MockerFixture, job_context: JobContext, polar_worker_context: PolarWorkerContext, session: AsyncSession, - user: User, + customer: Customer, ) -> None: enqueue_grants_after_precondition_fulfilled_mock = mocker.patch.object( benefit_grant_service, @@ -417,7 +417,7 @@ async def test_existing_user( session.expunge_all() await benefit_precondition_fulfilled( - job_context, user.id, BenefitType.custom, polar_worker_context + job_context, customer.id, BenefitType.custom, polar_worker_context ) enqueue_grants_after_precondition_fulfilled_mock.assert_called_once() diff --git a/server/tests/fixtures/downloadable.py b/server/tests/fixtures/downloadable.py index d007f80ebf..9e2fce90fe 100644 --- a/server/tests/fixtures/downloadable.py +++ b/server/tests/fixtures/downloadable.py @@ -5,7 +5,7 @@ from polar.benefit.benefits.downloadables import BenefitDownloadablesService from polar.benefit.schemas import BenefitDownloadablesCreateProperties -from polar.models import Benefit, Downloadable, File, Organization, Product, User +from polar.models import Benefit, Customer, Downloadable, File, Organization, Product from polar.models.benefit import BenefitDownloadables, BenefitType from polar.models.benefit_grant import BenefitGrantDownloadablesProperties from polar.models.subscription import SubscriptionStatus @@ -26,7 +26,7 @@ async def create_benefit_and_grant( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, properties: BenefitDownloadablesCreateProperties, @@ -42,7 +42,7 @@ async def create_benefit_and_grant( redis, save_fixture, cast(BenefitDownloadables, benefit), - user=user, + customer=customer, product=product, ) @@ -53,22 +53,22 @@ async def create_grant( redis: Redis, save_fixture: SaveFixture, benefit: BenefitDownloadables, - user: User, + customer: Customer, product: Product, ) -> tuple[BenefitDownloadables, BenefitGrantDownloadablesProperties]: subscription = await create_subscription( save_fixture, product=product, - user=user, + customer=customer, status=SubscriptionStatus.active, ) await create_benefit_grant( save_fixture, - user, + customer, benefit, subscription=subscription, ) - return await cls.run_grant_task(session, redis, benefit, user) + return await cls.run_grant_task(session, redis, benefit, customer) @classmethod async def run_grant_task( @@ -76,10 +76,10 @@ async def run_grant_task( session: AsyncSession, redis: Redis, benefit: BenefitDownloadables, - user: User, + customer: Customer, ) -> tuple[BenefitDownloadables, BenefitGrantDownloadablesProperties]: service = BenefitDownloadablesService(session, redis) - granted = await service.grant(benefit, user, {}) + granted = await service.grant(benefit, customer, {}) return benefit, granted @classmethod @@ -88,15 +88,15 @@ async def run_revoke_task( session: AsyncSession, redis: Redis, benefit: BenefitDownloadables, - user: User, + customer: Customer, ) -> tuple[BenefitDownloadables, BenefitGrantDownloadablesProperties]: service = BenefitDownloadablesService(session, redis) - revoked = await service.revoke(benefit, user, {}) + revoked = await service.revoke(benefit, customer, {}) return benefit, revoked @classmethod - async def get_user_downloadables( - cls, session: AsyncSession, user: User + async def get_customer_downloadables( + cls, session: AsyncSession, customer: Customer ) -> Sequence[Downloadable]: statement = ( sql.select(Downloadable) @@ -104,7 +104,7 @@ async def get_user_downloadables( .join(Benefit) .options(contains_eager(Downloadable.file)) .where( - Downloadable.user_id == user.id, + Downloadable.customer_id == customer.id, File.deleted_at.is_(None), File.is_uploaded == True, # noqa File.is_enabled == True, # noqa diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 205f67b23b..2773116aab 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -1272,6 +1272,18 @@ async def customer( return await create_customer(save_fixture, organization=organization) +@pytest_asyncio.fixture +async def customer_second( + save_fixture: SaveFixture, + organization: Organization, +) -> Customer: + return await create_customer( + save_fixture, + organization=organization, + stripe_customer_id="STRIPE_CUSTOMER_ID_2", + ) + + @pytest_asyncio.fixture async def subscription( save_fixture: SaveFixture, @@ -1283,13 +1295,13 @@ async def subscription( async def create_benefit_grant( save_fixture: SaveFixture, - user: User, + customer: Customer, benefit: Benefit, granted: bool | None = None, properties: BenefitGrantProperties | None = None, **scope: Unpack[BenefitGrantScope], ) -> BenefitGrant: - grant = BenefitGrant(benefit=benefit, user=user, **scope) + grant = BenefitGrant(benefit=benefit, customer=customer, **scope) if granted is not None: grant.set_granted() if granted else grant.set_revoked() if properties is not None: From be36c0d1f0833a40e3251a01bb75ec299f943690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 4 Dec 2024 10:29:04 +0100 Subject: [PATCH 08/47] server/checkout: allow to preset Customer when creating session --- server/polar/checkout/schemas.py | 8 ++++ server/polar/checkout/service.py | 40 +++++++++++++++--- server/polar/customer/service.py | 2 +- server/tests/checkout/test_service.py | 56 +++++++++++++++++++++++++ server/tests/fixtures/random_objects.py | 2 + 5 files changed, 102 insertions(+), 6 deletions(-) diff --git a/server/polar/checkout/schemas.py b/server/polar/checkout/schemas.py index e0bd1a12db..b57d337470 100644 --- a/server/polar/checkout/schemas.py +++ b/server/polar/checkout/schemas.py @@ -121,6 +121,14 @@ class CheckoutCreateBase(CustomFieldDataInputMixin, MetadataInputMixin, Schema): default=True, description=_allow_discount_codes_description ) amount: Amount | None = None + customer_id: UUID4 | None = Field( + default=None, + description=( + "ID of an existing customer in the organization. " + "The customer data will be pre-filled in the checkout form. " + "The resulting order will be linked to this customer." + ), + ) customer_name: Annotated[CustomerName | None, EmptyStrToNoneValidator] = None customer_email: CustomerEmail | None = None customer_ip_address: CustomerIPAddress | None = None diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index 71e4ee83eb..20a3739937 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -22,6 +22,7 @@ ) from polar.config import settings from polar.custom_field.data import validate_custom_field_data +from polar.customer.service import customer as customer_service from polar.discount.service import DiscountNotRedeemableError from polar.discount.service import discount as discount_service from polar.enums import PaymentProcessor @@ -304,14 +305,29 @@ async def create( ] ) from e + product = await self._eager_load_product(session, product) + subscription: Subscription | None = None customer: Customer | None = None if checkout_create.subscription_id is not None: subscription, customer = await self._get_validated_subscription( session, checkout_create.subscription_id, product.organization_id ) - - product = await self._eager_load_product(session, product) + elif checkout_create.customer_id is not None: + customer = await customer_service.get_by_id_and_organization( + session, checkout_create.customer_id, product.organization + ) + if customer is None: + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "customer_id"), + "msg": "Customer does not exist.", + "input": checkout_create.customer_id, + } + ] + ) amount = checkout_create.amount currency = None @@ -354,10 +370,24 @@ async def create( by_alias=True, ), ) - session.add(checkout) - if checkout.customer is not None and checkout.customer_email is None: - checkout.customer_email = checkout.customer.email + if checkout.customer is not None: + prefill_attributes = ( + "email", + "name", + "billing_address", + "tax_id", + ) + for attribute in prefill_attributes: + checkout_attribute = f"customer_{attribute}" + if getattr(checkout, checkout_attribute) is None: + setattr( + checkout, + checkout_attribute, + getattr(checkout.customer, attribute), + ) + + session.add(checkout) checkout = await self._update_checkout_ip_geolocation( session, checkout, ip_geolocation_client diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 8057cdea30..f83daf7b00 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -101,7 +101,7 @@ async def delete(self, session: AsyncSession, customer: Customer) -> Customer: return customer async def get_by_id_and_organization( - self, session: AsyncSession, id: str, organization: Organization + self, session: AsyncSession, id: uuid.UUID, organization: Organization ) -> Customer | None: statement = select(Customer).where( Customer.deleted_at.is_(None), diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index ef04a46f6c..c95ff00fc9 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -978,6 +978,62 @@ async def test_product_valid( assert checkout.product == product_one_time assert checkout.product_price == product_one_time.prices[0] + @pytest.mark.auth( + AuthSubjectFixture(subject="user"), + AuthSubjectFixture(subject="organization"), + ) + async def test_invalid_customer( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + user_organization: UserOrganization, + product_one_time: Product, + ) -> None: + price = product_one_time.prices[0] + assert isinstance(price, ProductPriceFixed) + + with pytest.raises(PolarRequestValidationError): + await checkout_service.create( + session, + CheckoutPriceCreate( + payment_processor=PaymentProcessor.stripe, + product_price_id=price.id, + customer_id=uuid.uuid4(), + ), + auth_subject, + ) + + @pytest.mark.auth( + AuthSubjectFixture(subject="user"), + AuthSubjectFixture(subject="organization"), + ) + async def test_valid_customer( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Organization], + user_organization: UserOrganization, + product_one_time: Product, + customer: Customer, + ) -> None: + price = product_one_time.prices[0] + assert isinstance(price, ProductPriceFixed) + + checkout = await checkout_service.create( + session, + CheckoutPriceCreate( + payment_processor=PaymentProcessor.stripe, + product_price_id=price.id, + customer_id=customer.id, + ), + auth_subject, + ) + + assert checkout.customer == customer + assert checkout.customer_email == customer.email + assert checkout.customer_name == customer.name + assert checkout.customer_billing_address == customer.billing_address + assert checkout.customer_tax_id == customer.tax_id + @pytest.mark.asyncio @pytest.mark.skip_db_asserts diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 2773116aab..b914a9a4dc 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -839,11 +839,13 @@ async def create_customer( organization: Organization, email: str = "customer@example.com", email_verified: bool = False, + name="Customer", stripe_customer_id: str = "STRIPE_CUSTOMER_ID", ) -> Customer: customer = Customer( email=email, email_verified=email_verified, + name=name, stripe_customer_id=stripe_customer_id, organization=organization, ) From 0bb56673959f96ab2bb356a38c432bd83622dbf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 5 Dec 2024 09:49:02 +0100 Subject: [PATCH 09/47] server: make customer unique by email and organization --- server/polar/checkout/service.py | 20 ++-- server/polar/customer/service.py | 92 ++++++++++++++--- server/polar/models/customer.py | 19 +++- server/polar/order/service.py | 13 +-- server/polar/subscription/service.py | 13 +-- server/tests/checkout/test_service.py | 30 ++++++ server/tests/customer/__init__.py | 0 server/tests/customer/test_service.py | 126 ++++++++++++++++++++++++ server/tests/fixtures/random_objects.py | 1 + server/tests/order/test_service.py | 37 ++++--- 10 files changed, 301 insertions(+), 50 deletions(-) create mode 100644 server/tests/customer/__init__.py create mode 100644 server/tests/customer/test_service.py diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index 20a3739937..b8993e91a1 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -1548,15 +1548,19 @@ async def _create_or_update_customer( customer = checkout.customer if customer is None: assert checkout.customer_email is not None - customer = Customer( - email=checkout.customer_email, - email_verified=False, - stripe_customer_id=None, - name=checkout.customer_name, - billing_address=checkout.customer_billing_address, - tax_id=checkout.customer_tax_id, - organization=checkout.organization, + customer = await customer_service.get_by_email_and_organization( + session, checkout.customer_email, checkout.organization ) + if customer is None: + customer = Customer( + email=checkout.customer_email, + email_verified=False, + stripe_customer_id=None, + name=checkout.customer_name, + billing_address=checkout.customer_billing_address, + tax_id=checkout.customer_tax_id, + organization=checkout.organization, + ) stripe_customer_id = customer.stripe_customer_id if stripe_customer_id is None: diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index f83daf7b00..f667adbacb 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -2,12 +2,12 @@ from collections.abc import Sequence from typing import Any -from sqlalchemy import Select, UnaryExpression, asc, desc, select +from sqlalchemy import Select, UnaryExpression, asc, desc, func, select from stripe import Customer as StripeCustomer from polar.auth.models import AuthSubject, is_organization, is_user from polar.authz.service import AccessType, Authz -from polar.exceptions import NotPermitted +from polar.exceptions import PolarRequestValidationError from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.kit.sorting import Sorting @@ -74,7 +74,30 @@ async def create( session, auth_subject, customer_create ) if not await authz.can(subject, AccessType.write, organization): - raise NotPermitted() + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "organization_id"), + "msg": "Organization not found.", + "input": organization.id, + } + ] + ) + + if await self.get_by_email_and_organization( + session, customer_create.email, organization + ): + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "email"), + "msg": "A customer with this email address already exists.", + "input": customer_create.email, + } + ] + ) customer = Customer( organization=organization, @@ -87,6 +110,24 @@ async def create( async def update( self, session: AsyncSession, customer: Customer, customer_update: CustomerUpdate ) -> Customer: + if ( + customer_update.email is not None + and customer.email.lower() != customer_update.email.lower() + and await self.get_by_email_and_organization( + session, customer_update.email, customer.organization + ) + ): + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "email"), + "msg": "A customer with this email address already exists.", + "input": customer_update.email, + } + ] + ) + for attr, value in customer_update.model_dump(exclude_unset=True).items(): setattr(customer, attr, value) @@ -111,6 +152,16 @@ async def get_by_id_and_organization( result = await session.execute(statement) return result.scalar_one_or_none() + async def get_by_email_and_organization( + self, session: AsyncSession, email: str, organization: Organization + ) -> Customer | None: + statement = select(Customer).where( + func.lower(Customer.email) == email.lower(), + Customer.organization_id == organization.id, + ) + result = await session.execute(statement) + return result.scalar_one_or_none() + async def get_by_stripe_customer_id( self, session: AsyncSession, stripe_customer_id: str ) -> Customer | None: @@ -121,21 +172,36 @@ async def get_by_stripe_customer_id( result = await session.execute(statement) return result.scalar_one_or_none() - async def create_from_stripe_customer( + async def get_or_create_from_stripe_customer( self, session: AsyncSession, stripe_customer: StripeCustomer, organization: Organization, ) -> Customer: - customer = Customer( - email=stripe_customer.email, - email_verified=False, - stripe_customer_id=stripe_customer.id, - name=stripe_customer.name, - billing_address=stripe_customer.address, - # TODO: tax_id, - organization=organization, - ) + """ + Get or create a customer from a Stripe customer object. + + Make a first lookup by the Stripe customer ID, then by the email address. + + If the customer does not exist, create a new one. + """ + customer = await self.get_by_stripe_customer_id(session, stripe_customer.id) + assert stripe_customer.email is not None + if customer is None: + customer = await self.get_by_email_and_organization( + session, stripe_customer.email, organization + ) + if customer is None: + customer = Customer( + email=stripe_customer.email, + email_verified=False, + stripe_customer_id=stripe_customer.id, + name=stripe_customer.name, + billing_address=stripe_customer.address, + # TODO: tax_id, + organization=organization, + ) + customer.stripe_customer_id = stripe_customer.id session.add(customer) return customer diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index b5d1adc380..ed46b7734a 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -1,7 +1,15 @@ from typing import TYPE_CHECKING from uuid import UUID -from sqlalchemy import Boolean, ForeignKey, String, Uuid +from sqlalchemy import ( + Boolean, + Column, + ForeignKey, + Index, + String, + Uuid, + func, +) from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from polar.kit.address import Address, AddressType @@ -15,6 +23,15 @@ class Customer(MetadataMixin, RecordModel): __tablename__ = "customers" + __table_args__ = ( + Index("ix_customers_email_case_insensitive", func.lower(Column("email"))), + Index( + "ix_customers_organization_id_email_case_insensitive", + "organization_id", + func.lower(Column("email")), + unique=True, + ), + ) email: Mapped[str] = mapped_column(String(320), nullable=False) email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index 47cf0862d7..c1872a0027 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -409,16 +409,13 @@ async def create_order_from_stripe( # Get or create customer assert invoice.customer is not None - stripe_customer_id = get_expandable_id(invoice.customer) if customer is None: - customer = await customer_service.get_by_stripe_customer_id( - session, stripe_customer_id + stripe_customer = await stripe_service.get_customer( + get_expandable_id(invoice.customer) + ) + customer = await customer_service.get_or_create_from_stripe_customer( + session, stripe_customer, product.organization ) - if customer is None: - stripe_customer = await stripe_service.get_customer(stripe_customer_id) - customer = await customer_service.create_from_stripe_customer( - session, stripe_customer, product.organization - ) order.customer = customer session.add(order) diff --git a/server/polar/subscription/service.py b/server/polar/subscription/service.py index fcb5214565..3ce6065991 100644 --- a/server/polar/subscription/service.py +++ b/server/polar/subscription/service.py @@ -459,15 +459,12 @@ async def create_subscription_from_stripe( # Take customer from existing subscription, or retrieve it from Stripe Customer ID if subscription.customer is None: - stripe_customer_id = get_expandable_id(stripe_subscription.customer) - customer = await customer_service.get_by_stripe_customer_id( - session, stripe_customer_id + stripe_customer = await stripe_service.get_customer( + get_expandable_id(stripe_subscription.customer) + ) + customer = await customer_service.get_or_create_from_stripe_customer( + session, stripe_customer, subscription_tier_org ) - if customer is None: - stripe_customer = await stripe_service.get_customer(stripe_customer_id) - customer = await customer_service.create_from_stripe_customer( - session, stripe_customer, subscription_tier_org - ) subscription.customer = customer session.add(subscription) diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index c95ff00fc9..75743e2571 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -2036,6 +2036,36 @@ async def test_valid_stripe_existing_customer( assert checkout.status == CheckoutStatus.confirmed stripe_service_mock.update_customer.assert_called_once() + async def test_valid_stripe_existing_customer_email( + self, + stripe_service_mock: MagicMock, + session: AsyncSession, + locker: Locker, + checkout_one_time_fixed: Checkout, + customer: Customer, + ) -> None: + stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( + client_secret="CLIENT_SECRET", status="succeeded" + ) + + checkout = await checkout_service.confirm( + session, + locker, + checkout_one_time_fixed, + CheckoutConfirmStripe.model_validate( + { + "confirmation_token_id": "CONFIRMATION_TOKEN_ID", + "customer_email": customer.email, + "customer_name": "Customer Name", + "customer_billing_address": {"country": "FR"}, + } + ), + ) + + assert checkout.status == CheckoutStatus.confirmed + assert checkout.customer == customer + stripe_service_mock.update_customer.assert_called_once() + def build_stripe_payment_intent( *, diff --git a/server/tests/customer/__init__.py b/server/tests/customer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/customer/test_service.py b/server/tests/customer/test_service.py new file mode 100644 index 0000000000..4fe0eb892b --- /dev/null +++ b/server/tests/customer/test_service.py @@ -0,0 +1,126 @@ +from typing import Any + +import pytest + +from polar.auth.models import AuthSubject, is_user +from polar.authz.service import Authz +from polar.customer.schemas import CustomerCreate, CustomerUpdate +from polar.customer.service import customer as customer_service +from polar.exceptions import PolarRequestValidationError +from polar.models import Customer, Organization, User, UserOrganization +from polar.postgres import AsyncSession +from tests.fixtures.auth import AuthSubjectFixture + + +@pytest.fixture +def authz(session: AsyncSession) -> Authz: + return Authz(session) + + +@pytest.mark.asyncio +@pytest.mark.skip_db_asserts +class TestCreate: + @pytest.mark.auth + async def test_not_accessible_organization( + self, + session: AsyncSession, + authz: Authz, + auth_subject: AuthSubject[User], + organization: Organization, + ) -> None: + with pytest.raises(PolarRequestValidationError): + await customer_service.create( + session, + authz, + CustomerCreate( + email="customer@example.com", organization_id=organization.id + ), + auth_subject, + ) + + @pytest.mark.auth( + AuthSubjectFixture(subject="user"), AuthSubjectFixture(subject="organization") + ) + async def test_existing_email( + self, + session: AsyncSession, + authz: Authz, + auth_subject: AuthSubject[User | Organization], + organization: Organization, + user_organization: UserOrganization, + customer: Customer, + ) -> None: + payload: dict[str, Any] = { + "email": customer.email.upper() # Check case-insensitive index + } + if is_user(auth_subject): + payload["organization_id"] = str(organization.id) + + with pytest.raises(PolarRequestValidationError): + await customer_service.create( + session, + authz, + CustomerCreate.model_validate(payload), + auth_subject, + ) + await session.flush() + + @pytest.mark.auth( + AuthSubjectFixture(subject="user"), AuthSubjectFixture(subject="organization") + ) + async def test_valid( + self, + session: AsyncSession, + authz: Authz, + auth_subject: AuthSubject[User | Organization], + organization: Organization, + user_organization: UserOrganization, + ) -> None: + payload: dict[str, Any] = {"email": "customer@example.com"} + if is_user(auth_subject): + payload["organization_id"] = str(organization.id) + + customer = await customer_service.create( + session, + authz, + CustomerCreate.model_validate(payload), + auth_subject, + ) + await session.flush() + + assert customer.email == "customer@example.com" + + +@pytest.mark.asyncio +@pytest.mark.skip_db_asserts +class TestUpdate: + async def test_existing_email( + self, session: AsyncSession, customer: Customer, customer_second: Customer + ) -> None: + with pytest.raises(PolarRequestValidationError): + await customer_service.update( + session, + customer, + CustomerUpdate(email=customer_second.email), + ) + await session.flush() + + @pytest.mark.parametrize( + "email", + [ + pytest.param("customer@example.com", id="same email"), + pytest.param("customer.updated@example.cm", id="different email"), + ], + ) + async def test_valid( + self, email: str, session: AsyncSession, customer: Customer + ) -> None: + customer = await customer_service.update( + session, + customer, + CustomerUpdate(email=email, name="John"), + ) + await session.flush() + + assert customer.email == email + assert customer.name == "John" diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index b914a9a4dc..9b0d6ae5f1 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -1282,6 +1282,7 @@ async def customer_second( return await create_customer( save_fixture, organization=organization, + email="customer.second@example.com", stripe_customer_id="STRIPE_CUSTOMER_ID_2", ) diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index 3ff4de32d8..a5199d067d 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -1,5 +1,6 @@ import time from datetime import datetime +from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -103,6 +104,21 @@ def construct_stripe_invoice( ) +@pytest.fixture(autouse=True) +def stripe_service_mock(mocker: MockerFixture, customer: Customer) -> MagicMock: + mock = MagicMock(spec=StripeService) + mocker.patch("polar.order.service.stripe_service", new=mock) + + mock.get_customer.return_value = SimpleNamespace( + id=customer.stripe_customer_id, + email=customer.email, + name=customer.name, + address=customer.billing_address, + ) + + return mock + + @pytest.fixture def enqueue_job_mock(mocker: MockerFixture) -> AsyncMock: return mocker.patch("polar.order.service.enqueue_job") @@ -837,6 +853,7 @@ async def test_one_time_free_product( async def test_charge_from_metadata( self, enqueue_job_mock: AsyncMock, + stripe_service_mock: MagicMock, mocker: MockerFixture, session: AsyncSession, save_fixture: SaveFixture, @@ -845,10 +862,10 @@ async def test_charge_from_metadata( customer: Customer, event_creation_time: tuple[datetime, int], ) -> None: - mock = MagicMock(spec=StripeService) - mocker.patch("polar.order.service.stripe_service", new=mock) - mock.get_payment_intent.return_value = stripe_lib.PaymentIntent.construct_from( - {"latest_charge": "CHARGE_ID"}, key=None + stripe_service_mock.get_payment_intent.return_value = ( + stripe_lib.PaymentIntent.construct_from( + {"latest_charge": "CHARGE_ID"}, key=None + ) ) created_datetime, created_unix_timestamp = event_creation_time @@ -919,17 +936,15 @@ async def test_charge_from_metadata( async def test_no_billing_address( self, customer_address: dict[str, Any] | None, + stripe_service_mock: MagicMock, save_fixture: SaveFixture, mocker: MockerFixture, session: AsyncSession, product: Product, - customer: Customer, organization_account: Account, event_creation_time: tuple[datetime, int], ) -> None: - mock = MagicMock(spec=StripeService) - mocker.patch("polar.order.service.stripe_service", new=mock) - mock.get_charge.return_value = stripe_lib.Charge.construct_from( + stripe_service_mock.get_charge.return_value = stripe_lib.Charge.construct_from( {"id": "CHARGE_ID", "payment_method_details": None}, key=None, ) @@ -975,16 +990,14 @@ async def test_no_billing_address( async def test_billing_address_from_payment_method( self, mocker: MockerFixture, + stripe_service_mock: MagicMock, save_fixture: SaveFixture, session: AsyncSession, product_one_time: Product, - customer: Customer, organization_account: Account, event_creation_time: tuple[datetime, int], ) -> None: - mock = MagicMock(spec=StripeService) - mocker.patch("polar.order.service.stripe_service", new=mock) - mock.get_charge.return_value = stripe_lib.Charge.construct_from( + stripe_service_mock.get_charge.return_value = stripe_lib.Charge.construct_from( { "id": "CHARGE_ID", "payment_method_details": { From 5c36457c6de77b05d69c4589580030a0535e6e2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 5 Dec 2024 11:24:03 +0100 Subject: [PATCH 10/47] server/benefit: add oauth acconts to Customer and revamp Discord/GitHub benefits with a require action pattern --- server/polar/benefit/benefits/__init__.py | 4 +- server/polar/benefit/benefits/base.py | 25 +--- server/polar/benefit/benefits/discord.py | 76 +++-------- .../benefit/benefits/github_repository.py | 113 ++++------------ server/polar/benefit/service/benefit_grant.py | 100 +------------- server/polar/benefit/tasks.py | 18 --- server/polar/integrations/discord/service.py | 92 +++++++++++-- server/polar/integrations/github/endpoints.py | 8 -- server/polar/models/benefit_grant.py | 8 +- server/polar/models/customer.py | 39 +++++- server/polar/notifications/notification.py | 33 +---- .../benefit/service/test_benefit_grant.py | 128 +----------------- server/tests/benefit/test_tasks.py | 45 ------ 13 files changed, 179 insertions(+), 510 deletions(-) diff --git a/server/polar/benefit/benefits/__init__.py b/server/polar/benefit/benefits/__init__.py index 04708ca8f2..06e8e057c7 100644 --- a/server/polar/benefit/benefits/__init__.py +++ b/server/polar/benefit/benefits/__init__.py @@ -11,7 +11,7 @@ from .ads import BenefitAdsService from .base import ( - BenefitPreconditionError, + BenefitActionRequiredError, BenefitPropertiesValidationError, BenefitRetriableError, BenefitServiceError, @@ -43,9 +43,9 @@ def get_benefit_service( __all__ = [ + "BenefitActionRequiredError", "BenefitServiceProtocol", "BenefitPropertiesValidationError", - "BenefitPreconditionError", "BenefitRetriableError", "BenefitServiceError", "get_benefit_service", diff --git a/server/polar/benefit/benefits/base.py b/server/polar/benefit/benefits/base.py index f599e390d6..dc2d8e5ee7 100644 --- a/server/polar/benefit/benefits/base.py +++ b/server/polar/benefit/benefits/base.py @@ -4,9 +4,6 @@ from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError from polar.models import Benefit, Customer, Organization, User from polar.models.benefit import BenefitProperties -from polar.notifications.notification import ( - BenefitPreconditionErrorNotificationContextualPayload, -) from polar.postgres import AsyncSession from polar.redis import Redis @@ -46,28 +43,13 @@ def __init__(self, defer_seconds: int) -> None: super().__init__(message) -class BenefitPreconditionError(BenefitServiceError): +class BenefitActionRequiredError(BenefitServiceError): """ - Some conditions are missing to grant the benefit. + An action is required from the customer before granting the benefit. - It accepts a payload schema. - When set, a notification will be sent to the backer to explain them what happened. + Typically, we need the customer to connect an external OAuth account. """ - def __init__( - self, - message: str, - *, - payload: BenefitPreconditionErrorNotificationContextualPayload | None = None, - ) -> None: - """ - Args: - message: The plain error message. - payload: The payload to build the notification. - """ - self.payload = payload - super().__init__(message) - B = TypeVar("B", bound=Benefit, contravariant=True) BP = TypeVar("BP", bound=BenefitProperties) @@ -122,7 +104,6 @@ async def grant( Raises: BenefitRetriableError: An temporary error occured, we should be able to retry later. - BenefitPreconditionError: Some conditions are missing to grant the benefit. """ ... diff --git a/server/polar/benefit/benefits/discord.py b/server/polar/benefit/benefits/discord.py index aa01ac96a4..7211b22da8 100644 --- a/server/polar/benefit/benefits/discord.py +++ b/server/polar/benefit/benefits/discord.py @@ -4,20 +4,17 @@ import structlog from polar.auth.models import AuthSubject -from polar.config import settings -from polar.integrations.discord.service import DiscordAccountNotConnected from polar.integrations.discord.service import discord_bot as discord_bot_service -from polar.integrations.discord.service import discord_user as discord_user_service +from polar.integrations.discord.service import ( + discord_customer as discord_customer_service, +) from polar.logging import Logger from polar.models import Customer, Organization, User from polar.models.benefit import BenefitDiscord, BenefitDiscordProperties from polar.models.benefit_grant import BenefitGrantDiscordProperties -from polar.notifications.notification import ( - BenefitPreconditionErrorNotificationContextualPayload, -) from .base import ( - BenefitPreconditionError, + BenefitActionRequiredError, BenefitPropertiesValidationError, BenefitRetriableError, BenefitServiceProtocol, @@ -25,42 +22,6 @@ log: Logger = structlog.get_logger() -precondition_error_subject_template = ( - "Action required: get access to {organization_name}'s Discord server" -) -precondition_error_body_template = """ -

Hi,

-

You just subscribed to {scope_name} from {organization_name}. Thank you!

-

As you may know, it includes an access to a private Discord server. To grant you access, we need you to link your Discord account on Polar.

-

Once done, you'll automatically be added to {organization_name}'s Discord server.

- - - - - - - - - - - - -""" - class BenefitDiscordService( BenefitServiceProtocol[ @@ -96,22 +57,19 @@ async def grant( ) await self.revoke(benefit, customer, grant_properties, attempt=attempt) - # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User + if (account_id := grant_properties.get("account_id")) is None: + raise BenefitActionRequiredError( + "The customer needs to connect their Discord account" + ) - try: - account = await discord_user_service.get_oauth_account(self.session, user) - except DiscordAccountNotConnected as e: - raise BenefitPreconditionError( - "Discord account not linked", - payload=BenefitPreconditionErrorNotificationContextualPayload( - subject_template=precondition_error_subject_template, - body_template=precondition_error_body_template, - extra_context={"url": settings.generate_frontend_url("/settings")}, - ), - ) from e + oauth_account = await discord_customer_service.get_oauth_account( + self.session, customer, account_id + ) try: - await discord_bot_service.add_member(self.session, guild_id, role_id, user) + await discord_bot_service.add_member( + guild_id, role_id, oauth_account.account_id, oauth_account.access_token + ) except httpx.HTTPError as e: error_bound_logger = bound_logger.bind(error=str(e)) if isinstance(e, httpx.HTTPStatusError): @@ -123,13 +81,11 @@ async def grant( bound_logger.debug("Benefit granted") - # Store guild, role and account IDs as it may change for various reasons: - # * The benefit is updated - # * The user disconnects or changes their Discord account + # Store guild, and role an as it may change if the benefit is updated return { + **grant_properties, "guild_id": guild_id, "role_id": role_id, - "account_id": account.account_id, } async def revoke( diff --git a/server/polar/benefit/benefits/github_repository.py b/server/polar/benefit/benefits/github_repository.py index a9df4e078e..74736e162a 100644 --- a/server/polar/benefit/benefits/github_repository.py +++ b/server/polar/benefit/benefits/github_repository.py @@ -19,15 +19,11 @@ BenefitGitHubRepositoryProperties, ) from polar.models.benefit_grant import BenefitGrantGitHubRepositoryProperties -from polar.models.user import OAuthPlatform -from polar.notifications.notification import ( - BenefitPreconditionErrorNotificationContextualPayload, -) from polar.posthog import posthog from polar.repository.service import repository as repository_service from .base import ( - BenefitPreconditionError, + BenefitActionRequiredError, BenefitPropertiesValidationError, BenefitRetriableError, BenefitServiceProtocol, @@ -35,40 +31,6 @@ log: Logger = structlog.get_logger() -precondition_error_subject_template = "Action required: get access to {extra_context[repository_owner]}/{extra_context[repository_name]} repository" -precondition_error_body_template = """ -

Hi,

-

You just subscribed to {scope_name} from {organization_name}. Thank you!

-

As you may know, it includes an access to {extra_context[repository_owner]}/{extra_context[repository_name]} repository on GitHub. To grant you access, we need you to link your GitHub account on Polar.

-

Once done, you'll automatically be invited to the repository.

- - - - - - - - - - - - -""" - class BenefitGitHubRepositoryService( BenefitServiceProtocol[ @@ -92,28 +54,24 @@ async def grant( ) bound_logger.debug("Grant benefit") - client = await self._get_github_app_client(bound_logger, benefit) + client = await self._get_github_app_client(benefit) repository_owner = benefit.properties["repository_owner"] repository_name = benefit.properties["repository_name"] permission = benefit.properties["permission"] - # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User + if (account_id := grant_properties.get("account_id")) is None: + raise BenefitActionRequiredError( + "The customer needs to connect their GitHub account" + ) + + oauth_account = customer.get_oauth_account( + f"github:{settings.GITHUB_CLIENT_ID}:{account_id}" + ) - # When inviting users: Use the users identity from the "main" Polar GitHub App - oauth_account = user.get_oauth_account(OAuthPlatform.github) if oauth_account is None or oauth_account.account_username is None: - raise BenefitPreconditionError( - "GitHub account not linked", - payload=BenefitPreconditionErrorNotificationContextualPayload( - subject_template=precondition_error_subject_template, - body_template=precondition_error_body_template, - extra_context={ - "repository_owner": repository_owner, - "repository_name": repository_name, - "url": settings.generate_frontend_url("/settings"), - }, - ), + raise BenefitActionRequiredError( + "The customer needs to connect their GitHub account" ) # If we already granted this benefit, make sure we revoke the previous config @@ -131,7 +89,7 @@ async def grant( or repository_name != grant_properties.get("repository_name") or invitation is not None ): - await self.revoke(benefit, user, grant_properties, attempt=attempt) + await self.revoke(benefit, customer, grant_properties, attempt=attempt) # The permission changed, and the invitation is already accepted elif permission != grant_properties.get("permission"): # The permission change will be handled by the add_collaborator call @@ -153,6 +111,7 @@ async def grant( # Store repository and permission to compare on update return { + **grant_properties, "repository_owner": repository_owner, "repository_name": repository_name, "permission": permission, @@ -175,16 +134,24 @@ async def revoke( bound_logger.info("skipping revoke for old version of this benefit type") return {} - client = await self._get_github_app_client(bound_logger, benefit) + client = await self._get_github_app_client(benefit) repository_owner = benefit.properties["repository_owner"] repository_name = benefit.properties["repository_name"] - # TODO: we need to revamp this, since we now need to get an account from a Customer instead of a User + if (account_id := grant_properties.get("account_id")) is None: + raise BenefitActionRequiredError( + "The customer needs to connect their GitHub account" + ) + + oauth_account = customer.get_oauth_account( + f"github:{settings.GITHUB_CLIENT_ID}:{account_id}" + ) - oauth_account = user.get_oauth_account(OAuthPlatform.github) if oauth_account is None or oauth_account.account_username is None: - raise + raise BenefitActionRequiredError( + "The customer needs to connect their GitHub account" + ) invitation = await self._get_invitation( client, @@ -250,14 +217,9 @@ async def validate_properties( assert is_user(auth_subject) user = auth_subject.subject - # old style - if properties["repository_id"]: - return await self._validate_properties_repository_id(user, properties) - repository_owner = properties["repository_owner"] repository_name = properties["repository_name"] - # new style oauth = await github_repository_benefit_user_service.get_oauth_account( self.session, user ) @@ -418,29 +380,8 @@ async def _get_invitation( return None async def _get_github_app_client( - self, - logger: Logger, - benefit: BenefitGitHubRepository, + self, benefit: BenefitGitHubRepository ) -> GitHub[AppInstallationAuthStrategy]: - # Old integrations, using the "Polar" GitHub App - if benefit.properties["repository_id"]: - logger.debug("using legacy integration") - repository_id = benefit.properties["repository_id"] - repository = await repository_service.get( - self.session, repository_id, load_organization=True - ) - assert repository is not None - organization = repository.organization - assert organization is not None - installation_id = organization.installation_id - assert installation_id is not None - return github.get_app_installation_client( - installation_id, redis=self.redis, app=github.GitHubApp.polar - ) - - # New integration, using the "Repository Benefit" GitHub App - logger.debug("using Repository Benefit app integration") - repository_owner = benefit.properties["repository_owner"] repository_name = benefit.properties["repository_name"] installation = ( diff --git a/server/polar/benefit/service/benefit_grant.py b/server/polar/benefit/service/benefit_grant.py index d2577df3d3..8b98f007df 100644 --- a/server/polar/benefit/service/benefit_grant.py +++ b/server/polar/benefit/service/benefit_grant.py @@ -6,7 +6,8 @@ from sqlalchemy import select from sqlalchemy.orm import joinedload -from polar.benefit.benefits import BenefitPreconditionError, get_benefit_service +from polar.benefit.benefits import get_benefit_service +from polar.benefit.benefits.base import BenefitActionRequiredError from polar.benefit.schemas import BenefitGrantWebhook from polar.customer.service import customer as customer_service from polar.eventstream.service import publish as eventstream_publish @@ -18,20 +19,13 @@ from polar.models.benefit import BenefitProperties, BenefitType from polar.models.benefit_grant import BenefitGrantPropertiesBase, BenefitGrantScope from polar.models.webhook_endpoint import WebhookEventType -from polar.notifications.notification import ( - BenefitPreconditionErrorNotificationPayload, - NotificationType, -) -from polar.notifications.service import PartialNotification -from polar.notifications.service import notifications as notification_service -from polar.organization.service import organization as organization_service from polar.postgres import AsyncSession, sql from polar.redis import Redis from polar.webhook.service import webhook as webhook_service from polar.webhook.webhooks import WebhookPayloadTypeAdapter from polar.worker import enqueue_job -from .benefit_grant_scope import resolve_scope, scope_to_args +from .benefit_grant_scope import scope_to_args log: Logger = structlog.get_logger() @@ -159,8 +153,7 @@ async def grant_benefit( grant.properties, attempt=attempt, ) - except BenefitPreconditionError as e: - await self.handle_precondition_error(session, e, customer, benefit, **scope) + except BenefitActionRequiredError: grant.granted_at = None else: grant.properties = properties @@ -331,9 +324,7 @@ async def update_benefit_grant( update=True, attempt=attempt, ) - except BenefitPreconditionError as e: - scope = await resolve_scope(session, grant.scope) - await self.handle_precondition_error(session, e, customer, benefit, **scope) + except BenefitActionRequiredError: grant.granted_at = None else: grant.properties = properties @@ -397,87 +388,6 @@ async def delete_benefit_grant( ) return grant - async def handle_precondition_error( - self, - session: AsyncSession, - error: BenefitPreconditionError, - customer: Customer, - benefit: Benefit, - **scope: Unpack[BenefitGrantScope], - ) -> None: - if error.payload is None: - log.warning( - "A precondition error was raised but the customer was not notified. " - "We probably should implement a notification for this error.", - benefit_id=str(benefit.id), - customer_id=str(customer.id), - scope=scope, - ) - return - - log.info( - "Precondition error while granting benefit. Customer was informed.", - benefit_id=str(benefit.id), - customer_id=str(customer.id), - ) - - # Disable the notification for now as it's a bit noisy for some use-cases - # We'll change how benefits are granted in the future so this won't be needed - return - - scope_name = "" - organization_name = "" - if subscription := scope.get("subscription"): - await session.refresh(subscription, {"product"}) - scope_name = subscription.product.name - subscription_tier = subscription.product - managing_organization = await organization_service.get( - session, subscription_tier.organization_id - ) - assert managing_organization is not None - organization_name = managing_organization.slug - - notification_payload = BenefitPreconditionErrorNotificationPayload( - scope_name=scope_name, - benefit_id=benefit.id, - benefit_description=benefit.description, - organization_name=organization_name, - **error.payload.model_dump(), - ) - - await notification_service.send_to_user( - session=session, - user_id=user.id, - notif=PartialNotification( - type=NotificationType.benefit_precondition_error, - payload=notification_payload, - ), - ) - - async def enqueue_grants_after_precondition_fulfilled( - self, - session: AsyncSession, - customer: Customer, - benefit_type: BenefitType, - ) -> None: - log.info( - "Enqueueing benefit grants after precondition fulfilled", - customer_id=str(customer.id), - benefit_type=benefit_type, - ) - - grants = await self._get_by_customer_and_benefit_type( - session, customer, benefit_type - ) - for grant in grants: - if not grant.is_granted and not grant.is_revoked: - enqueue_job( - "benefit.grant", - customer_id=customer.id, - benefit_id=grant.benefit_id, - **grant.scope, - ) - async def get_by_benefit_and_scope( self, session: AsyncSession, diff --git a/server/polar/benefit/tasks.py b/server/polar/benefit/tasks.py index 2110153c3e..ccd3f5d757 100644 --- a/server/polar/benefit/tasks.py +++ b/server/polar/benefit/tasks.py @@ -7,7 +7,6 @@ from polar.customer.service import customer as customer_service from polar.exceptions import PolarTaskError from polar.logging import Logger -from polar.models.benefit import BenefitType from polar.models.benefit_grant import BenefitGrantScopeArgs from polar.product.service.product import product as product_service from polar.worker import ( @@ -217,20 +216,3 @@ async def benefit_delete( benefit_grant_id=str(benefit_grant_id), ) raise Retry(e.defer_seconds) from e - - -@task("benefit.precondition_fulfilled") -async def benefit_precondition_fulfilled( - ctx: JobContext, - customer_id: uuid.UUID, - benefit_type: BenefitType, - polar_context: PolarWorkerContext, -) -> None: - async with AsyncSessionMaker(ctx) as session: - customer = await customer_service.get(session, customer_id) - if customer is None: - raise CustomerDoesNotExist(customer_id) - - await benefit_grant_service.enqueue_grants_after_precondition_fulfilled( - session, customer, benefit_type - ) diff --git a/server/polar/integrations/discord/service.py b/server/polar/integrations/discord/service.py index 53596f0744..e730ffdf3f 100644 --- a/server/polar/integrations/discord/service.py +++ b/server/polar/integrations/discord/service.py @@ -4,11 +4,10 @@ from polar.config import settings from polar.exceptions import PolarError from polar.logging import Logger -from polar.models import OAuthAccount, User -from polar.models.benefit import BenefitType +from polar.models import Customer, OAuthAccount, User +from polar.models.customer import CustomerOAuthAccount from polar.models.user import OAuthPlatform from polar.postgres import AsyncSession -from polar.worker import enqueue_job from . import oauth from .client import DiscordClient, bot_client @@ -34,6 +33,25 @@ def __init__(self, user: User) -> None: super().__init__(message, 401) +class DiscordCustomerAccountDoesNotExist(DiscordError): + def __init__(self, customer: Customer, account_id: str) -> None: + self.customer = customer + self.account_id = account_id + message = ( + f"The Discord account {account_id} does not exist " + f"on customer {customer.id}." + ) + super().__init__(message) + + +class DiscordCustomerExpiredAccessToken(DiscordError): + def __init__(self, customer: Customer, account_id: str) -> None: + self.customer = customer + self.account_id = account_id + message = "The access token is expired and no refresh token is available." + super().__init__(message, 401) + + class DiscordUserService: async def create_oauth_account( self, session: AsyncSession, user: User, oauth2_token_data: OAuth2Token @@ -60,13 +78,6 @@ async def create_oauth_account( session.add(oauth_account) await session.commit() - # Make sure potential Discord benefits are granted - enqueue_job( - "benefit.precondition_fulfilled", - user_id=user.id, - benefit_type=BenefitType.discord, - ) - return oauth_account async def update_user_info( @@ -115,6 +126,59 @@ async def get_oauth_account( return account +class DiscordCustomerService: + async def create_oauth_account( + self, session: AsyncSession, customer: Customer, oauth2_token_data: OAuth2Token + ) -> CustomerOAuthAccount: + access_token = oauth2_token_data["access_token"] + + client = DiscordClient("Bearer", access_token) + data = await client.get_me() + + account_id = data["id"] + oauth_account = CustomerOAuthAccount( + access_token=access_token, + expires_at=oauth2_token_data["expires_at"], + refresh_token=oauth2_token_data["refresh_token"], + account_id=data["id"], + ) + customer.set_oauth_account(self._get_account_key(account_id), oauth_account) + session.add(customer) + + return oauth_account + + async def get_oauth_account( + self, session: AsyncSession, customer: Customer, account_id: str + ) -> CustomerOAuthAccount: + account_key = self._get_account_key(account_id) + oauth_account = customer.get_oauth_account(account_key) + if oauth_account is None: + raise DiscordCustomerAccountDoesNotExist(customer, account_id) + + if oauth_account.is_expired(): + if oauth_account.refresh_token is None: + raise DiscordCustomerExpiredAccessToken(customer, account_id) + + log.debug( + "Refresh Discord access token", + oauth_account_id=oauth_account.account_id, + customer_id=str(customer.id), + ) + refreshed_token_data = await oauth.user_client.refresh_token( + oauth_account.refresh_token + ) + oauth_account.access_token = refreshed_token_data["access_token"] + oauth_account.expires_at = refreshed_token_data["expires_at"] + oauth_account.refresh_token = refreshed_token_data["refresh_token"] + customer.set_oauth_account(account_key, oauth_account) + session.add(customer) + + return oauth_account + + def _get_account_key(self, account_id: str) -> str: + return f"discord:{oauth.user_client.client_id}:{account_id}" + + class DiscordBotService: async def get_guild(self, id: str) -> DiscordGuild: guild = await bot_client.get_guild(id) @@ -138,13 +202,12 @@ async def get_guild(self, id: str) -> DiscordGuild: return DiscordGuild(name=guild["name"], roles=roles) async def add_member( - self, session: AsyncSession, guild_id: str, role_id: str, user: User + self, guild_id: str, role_id: str, account_id: str, access_token: str ) -> None: - oauth_account = await DiscordUserService().get_oauth_account(session, user) await bot_client.add_member( guild_id=guild_id, - discord_user_id=oauth_account.account_id, - discord_user_access_token=oauth_account.access_token, + discord_user_id=account_id, + discord_user_access_token=access_token, role_id=role_id, ) @@ -176,4 +239,5 @@ async def is_bot_role_above_role(self, guild_id: str, role_id: str) -> bool: discord_user = DiscordUserService() +discord_customer = DiscordCustomerService() discord_bot = DiscordBotService() diff --git a/server/polar/integrations/github/endpoints.py b/server/polar/integrations/github/endpoints.py index 912316d6f7..e561655946 100644 --- a/server/polar/integrations/github/endpoints.py +++ b/server/polar/integrations/github/endpoints.py @@ -32,7 +32,6 @@ from polar.kit.http import ReturnTo from polar.locker import Locker, get_locker from polar.models import ExternalOrganization -from polar.models.benefit import BenefitType from polar.openapi import APITag from polar.pledge.service import pledge as pledge_service from polar.postgres import AsyncSession, get_db_session @@ -181,13 +180,6 @@ async def github_callback( # connect dangling rewards await reward_service.connect_by_username(session, user) - # Make sure potential GitHub benefits are granted - enqueue_job( - "benefit.precondition_fulfilled", - user_id=user.id, - benefit_type=BenefitType.github_repository, - ) - # Event tracking last to ensure business critical data is stored first if is_signup: posthog.user_signup(user, "github") diff --git a/server/polar/models/benefit_grant.py b/server/polar/models/benefit_grant.py index e1e1ca293d..2b873f0ae5 100644 --- a/server/polar/models/benefit_grant.py +++ b/server/polar/models/benefit_grant.py @@ -80,17 +80,13 @@ class BenefitGrantAdsProperties(BenefitGrantPropertiesBase): class BenefitGrantDiscordProperties(BenefitGrantPropertiesBase, total=False): + account_id: str guild_id: str role_id: str - account_id: str class BenefitGrantGitHubRepositoryProperties(BenefitGrantPropertiesBase, total=False): - # repository_id was set previously (before 2024-13-15), for benefits using the "main" - # Polar GitHub App for granting benefits. Benefits created after this date are using - # the "Polar Repository Benefit" GitHub App, and only uses the repository_owner - # and repository_name fields. - repository_id: str | None + account_id: str repository_owner: str repository_name: str permission: Literal["pull", "triage", "push", "maintain", "admin"] diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index ed46b7734a..5544cc640e 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING +import dataclasses +import time +from typing import TYPE_CHECKING, Any from uuid import UUID from sqlalchemy import ( @@ -10,6 +12,7 @@ Uuid, func, ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from polar.kit.address import Address, AddressType @@ -21,6 +24,21 @@ from .organization import Organization +@dataclasses.dataclass +class CustomerOAuthAccount: + access_token: str + account_id: str + account_username: str | None = None + expires_at: int | None = None + refresh_token: str | None = None + refresh_token_expires_at: int | None = None + + def is_expired(self) -> bool: + if self.expires_at is None: + return False + return time.time() > self.expires_at + + class Customer(MetadataMixin, RecordModel): __tablename__ = "customers" __table_args__ = ( @@ -45,6 +63,10 @@ class Customer(MetadataMixin, RecordModel): ) tax_id: Mapped[TaxID | None] = mapped_column(TaxIDType, nullable=True, default=None) + _oauth_accounts: Mapped[dict[str, dict[str, Any]]] = mapped_column( + "oauth_accounts", JSONB, nullable=False, default=dict + ) + organization_id: Mapped[UUID] = mapped_column( Uuid, ForeignKey("organizations.id", ondelete="cascade"), @@ -54,3 +76,18 @@ class Customer(MetadataMixin, RecordModel): @declared_attr def organization(cls) -> Mapped["Organization"]: return relationship("Organization", lazy="raise") + + def get_oauth_account(self, account_key: str) -> CustomerOAuthAccount | None: + oauth_account_data = self._oauth_accounts.get(account_key) + if oauth_account_data is None: + return None + + return CustomerOAuthAccount(**oauth_account_data) + + def set_oauth_account( + self, account_key: str, oauth_account: CustomerOAuthAccount + ) -> None: + self._oauth_accounts[account_key] = dataclasses.asdict(oauth_account) + + def remove_oauth_account(self, account_key: str) -> None: + self._oauth_accounts.pop(account_key, None) diff --git a/server/polar/notifications/notification.py b/server/polar/notifications/notification.py index 49ff12c1da..1fa31c1d69 100644 --- a/server/polar/notifications/notification.py +++ b/server/polar/notifications/notification.py @@ -1,10 +1,10 @@ from abc import abstractmethod from datetime import datetime from enum import StrEnum -from typing import Annotated, Any, Literal +from typing import Annotated, Literal from uuid import UUID -from pydantic import UUID4, BaseModel, Discriminator, Field +from pydantic import UUID4, BaseModel, Discriminator from polar.email.renderer import get_email_renderer from polar.kit.money import get_cents_in_dollar_string @@ -30,7 +30,6 @@ class NotificationType(StrEnum): maintainer_account_reviewed = "MaintainerAccountReviewedNotification" maintainer_new_paid_subscription = "MaintainerNewPaidSubscriptionNotification" maintainer_new_product_sale = "MaintainerNewProductSaleNotification" - benefit_precondition_error = "BenefitPreconditionErrorNotification" maintainer_create_account = "MaintainerCreateAccountNotification" @@ -461,32 +460,6 @@ class MaintainerNewProductSaleNotification(NotificationBase): payload: MaintainerNewProductSaleNotificationPayload -class BenefitPreconditionErrorNotificationContextualPayload(BaseModel): - extra_context: dict[str, Any] = Field(default_factory=dict) - subject_template: str - body_template: str - - -class BenefitPreconditionErrorNotificationPayload( - NotificationPayloadBase, BenefitPreconditionErrorNotificationContextualPayload -): - scope_name: str - benefit_id: UUID - benefit_description: str - organization_name: str - - def subject(self) -> str: - return self.subject_template.format(**self.model_dump()) - - def body(self) -> str: - return self.body_template.format(**self.model_dump()) - - -class BenefitPreconditionErrorNotification(NotificationBase): - type: Literal[NotificationType.benefit_precondition_error] - payload: BenefitPreconditionErrorNotificationPayload - - class MaintainerCreateAccountNotificationPayload(NotificationPayloadBase): organization_name: str url: str @@ -550,7 +523,6 @@ class MaintainerCreateAccountNotification(NotificationBase): | MaintainerAccountReviewedNotificationPayload | MaintainerNewPaidSubscriptionNotificationPayload | MaintainerNewProductSaleNotificationPayload - | BenefitPreconditionErrorNotificationPayload | MaintainerCreateAccountNotificationPayload ) @@ -568,7 +540,6 @@ class MaintainerCreateAccountNotification(NotificationBase): | MaintainerAccountReviewedNotification | MaintainerNewPaidSubscriptionNotification | MaintainerNewProductSaleNotification - | BenefitPreconditionErrorNotification | MaintainerCreateAccountNotification, Discriminator(discriminator="type"), ] diff --git a/server/tests/benefit/service/test_benefit_grant.py b/server/tests/benefit/service/test_benefit_grant.py index 6834a802df..4a05c56592 100644 --- a/server/tests/benefit/service/test_benefit_grant.py +++ b/server/tests/benefit/service/test_benefit_grant.py @@ -4,21 +4,11 @@ import pytest from pytest_mock import MockerFixture -from polar.benefit.benefits import BenefitPreconditionError, BenefitServiceProtocol -from polar.benefit.service.benefit_grant import ( - benefit_grant as benefit_grant_service, -) -from polar.benefit.service.benefit_grant import ( # type: ignore[attr-defined] - notification_service, -) +from polar.benefit.benefits import BenefitActionRequiredError, BenefitServiceProtocol +from polar.benefit.service.benefit_grant import benefit_grant as benefit_grant_service from polar.models import Benefit, BenefitGrant, Customer, Product, Subscription -from polar.notifications.notification import ( - BenefitPreconditionErrorNotificationContextualPayload, -) -from polar.notifications.service import NotificationsService from polar.postgres import AsyncSession from polar.redis import Redis -from polar.subscription.service import subscription as subscription_service from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( create_benefit_grant, @@ -119,7 +109,7 @@ async def test_existing_grant_already_granted( assert updated_grant.is_granted benefit_service_mock.grant.assert_not_called() - async def test_precondition_error( + async def test_action_required_error( self, session: AsyncSession, redis: Redis, @@ -128,7 +118,7 @@ async def test_precondition_error( benefit_organization: Benefit, benefit_service_mock: MagicMock, ) -> None: - benefit_service_mock.grant.side_effect = BenefitPreconditionError("Error") + benefit_service_mock.grant.side_effect = BenefitActionRequiredError("Error") # then session.expunge_all() @@ -582,7 +572,7 @@ async def test_granted_grant( benefit_service_mock.grant.assert_called_once() assert benefit_service_mock.grant.call_args[1]["update"] is True - async def test_precondition_error( + async def test_TODO_error( self, session: AsyncSession, redis: Redis, @@ -598,7 +588,7 @@ async def test_precondition_error( grant.set_granted() await save_fixture(grant) - benefit_service_mock.grant.side_effect = BenefitPreconditionError("Error") + benefit_service_mock.grant.side_effect = BenefitActionRequiredError("Error") # then session.expunge_all() @@ -716,112 +706,6 @@ async def test_granted_grant( benefit_service_mock.revoke.assert_called_once() -@pytest.fixture -def notification_send_to_user_mock(mocker: MockerFixture) -> MagicMock: - return mocker.patch.object( - notification_service, "send_to_user", spec=NotificationsService.send_to_user - ) - - -@pytest.mark.asyncio -class TestHandlePreconditionError: - async def test_no_notification( - self, - session: AsyncSession, - subscription: Subscription, - benefit_organization: Benefit, - customer: Customer, - notification_send_to_user_mock: MagicMock, - ) -> None: - error = BenefitPreconditionError("Error") - - # then - session.expunge_all() - - await benefit_grant_service.handle_precondition_error( - session, error, customer, benefit_organization, subscription=subscription - ) - - notification_send_to_user_mock.assert_not_called() - - async def test_email( - self, - session: AsyncSession, - subscription: Subscription, - benefit_organization: Benefit, - customer: Customer, - notification_send_to_user_mock: MagicMock, - ) -> None: - error = BenefitPreconditionError( - "Error", - payload=BenefitPreconditionErrorNotificationContextualPayload( - subject_template="Action required for granting {subscription_benefit_name}", - body_template="Go here to fix this: {extra_context[url]}", - extra_context={"url": "https://polar.sh"}, - ), - ) - - # then - session.expunge_all() - - # load - subscription_loaded = await subscription_service.get(session, subscription.id) - assert subscription_loaded - - await benefit_grant_service.handle_precondition_error( - session, - error, - customer, - benefit_organization, - subscription=subscription_loaded, - ) - - notification_send_to_user_mock.assert_not_called() - - -@pytest.mark.asyncio -class TestEnqueueGrantsAfterPreconditionFulfilled: - async def test_required_update( - self, - mocker: MockerFixture, - session: AsyncSession, - save_fixture: SaveFixture, - subscription: Subscription, - customer: Customer, - customer_second: Customer, - benefit_organization: Benefit, - ) -> None: - pending_grant = BenefitGrant( - subscription=subscription, customer=customer, benefit=benefit_organization - ) - await save_fixture(pending_grant) - - other_user_grant = BenefitGrant( - subscription=subscription, - customer=customer_second, - benefit=benefit_organization, - ) - await save_fixture(other_user_grant) - - enqueue_job_mock = mocker.patch( - "polar.benefit.service.benefit_grant.enqueue_job" - ) - - # then - session.expunge_all() - - await benefit_grant_service.enqueue_grants_after_precondition_fulfilled( - session, customer, benefit_organization.type - ) - - enqueue_job_mock.assert_called_once_with( - "benefit.grant", - customer_id=customer.id, - benefit_id=pending_grant.benefit_id, - **pending_grant.scope, - ) - - @pytest.mark.asyncio class TestGetByBenefitAndScope: async def test_existing_grant_incorrect_scope( diff --git a/server/tests/benefit/test_tasks.py b/server/tests/benefit/test_tasks.py index 07103aa5f2..a063c66c69 100644 --- a/server/tests/benefit/test_tasks.py +++ b/server/tests/benefit/test_tasks.py @@ -15,12 +15,10 @@ benefit_delete, benefit_grant, benefit_grant_service, - benefit_precondition_fulfilled, benefit_revoke, benefit_update, ) from polar.models import Benefit, BenefitGrant, Customer, Subscription -from polar.models.benefit import BenefitType from polar.postgres import AsyncSession from polar.worker import JobContext, PolarWorkerContext from tests.fixtures.database import SaveFixture @@ -378,46 +376,3 @@ async def test_retry( with pytest.raises(Retry): await benefit_delete(job_context, grant.id, polar_worker_context) - - -@pytest.mark.asyncio -class TestBenefitPreconditionFulfilled: - async def test_not_existing_customer( - self, - job_context: JobContext, - polar_worker_context: PolarWorkerContext, - session: AsyncSession, - ) -> None: - # then - session.expunge_all() - - with pytest.raises(CustomerDoesNotExist): - await benefit_precondition_fulfilled( - job_context, - uuid.uuid4(), - BenefitType.custom, - polar_worker_context, - ) - - async def test_existing_customer( - self, - mocker: MockerFixture, - job_context: JobContext, - polar_worker_context: PolarWorkerContext, - session: AsyncSession, - customer: Customer, - ) -> None: - enqueue_grants_after_precondition_fulfilled_mock = mocker.patch.object( - benefit_grant_service, - "enqueue_grants_after_precondition_fulfilled", - spec=BenefitGrantService.enqueue_grants_after_precondition_fulfilled, - ) - - # then - session.expunge_all() - - await benefit_precondition_fulfilled( - job_context, customer.id, BenefitType.custom, polar_worker_context - ) - - enqueue_grants_after_precondition_fulfilled_mock.assert_called_once() From 2084ffeea69072e2bf86b1450031e401ecb58a06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 5 Dec 2024 15:33:41 +0100 Subject: [PATCH 11/47] server: revamp users/ endpoints into /customer-portal endpoints --- server/polar/api.py | 3 + server/polar/auth/dependencies.py | 5 +- server/polar/auth/models.py | 13 +- server/polar/auth/scope.py | 17 +- server/polar/benefit/benefits/license_keys.py | 4 +- server/polar/benefit/schemas.py | 18 +-- server/polar/customer_portal/__init__.py | 0 server/polar/customer_portal/auth.py | 27 ++++ .../customer_portal/endpoints/__init__.py | 15 ++ .../endpoints/benefit_grant.py} | 60 ++++--- .../endpoints/downloadables.py | 21 +-- .../endpoints/license_keys.py | 23 +-- .../endpoints/order.py | 46 +++--- .../endpoints/subscription.py | 42 ++--- .../polar/customer_portal/schemas/__init__.py | 0 .../customer_portal/schemas/benefit_grant.py | 82 ++++++++++ .../schemas/downloadables.py | 0 .../schemas/order.py | 19 ++- .../schemas/subscription.py | 31 ++-- .../customer_portal/service/benefit_grant.py | 118 ++++++++++++++ .../service/downloadables.py | 27 +++- .../service/order.py | 43 ++--- .../service/subscription.py | 52 +++--- server/polar/license_key/schemas.py | 27 +++- server/polar/license_key/service.py | 142 ++++++++--------- server/polar/models/benefit.py | 2 +- server/polar/user/auth.py | 60 +------ server/polar/user/endpoints/__init__.py | 20 +-- server/polar/user/schemas/benefit.py | 13 -- server/polar/user/service/benefit.py | 148 ------------------ server/tests/customer_portal/__init__.py | 0 .../endpoints/test_benefit_grant.py} | 18 ++- .../endpoints/test_downloadables.py | 90 +++++------ .../endpoints/test_license_keys.py | 112 ++++++------- .../tests/customer_portal/service/__init__.py | 0 .../service/test_benefit_grant.py} | 115 +++++++------- .../service/test_order.py | 69 ++++---- .../service/test_subscription.py | 97 +++++++----- server/tests/file/test_endpoints.py | 2 +- server/tests/fixtures/auth.py | 8 +- server/tests/fixtures/file.py | 83 ++++------ server/tests/fixtures/license_key.py | 28 ++-- server/tests/license_key/test_endpoints.py | 2 +- 43 files changed, 854 insertions(+), 848 deletions(-) create mode 100644 server/polar/customer_portal/__init__.py create mode 100644 server/polar/customer_portal/auth.py create mode 100644 server/polar/customer_portal/endpoints/__init__.py rename server/polar/{user/endpoints/benefit.py => customer_portal/endpoints/benefit_grant.py} (55%) rename server/polar/{user => customer_portal}/endpoints/downloadables.py (82%) rename server/polar/{user => customer_portal}/endpoints/license_keys.py (89%) rename server/polar/{user => customer_portal}/endpoints/order.py (71%) rename server/polar/{user => customer_portal}/endpoints/subscription.py (78%) create mode 100644 server/polar/customer_portal/schemas/__init__.py create mode 100644 server/polar/customer_portal/schemas/benefit_grant.py rename server/polar/{user => customer_portal}/schemas/downloadables.py (100%) rename server/polar/{user => customer_portal}/schemas/order.py (63%) rename server/polar/{user => customer_portal}/schemas/subscription.py (55%) create mode 100644 server/polar/customer_portal/service/benefit_grant.py rename server/polar/{user => customer_portal}/service/downloadables.py (90%) rename server/polar/{user => customer_portal}/service/order.py (78%) rename server/polar/{user => customer_portal}/service/subscription.py (87%) delete mode 100644 server/polar/user/schemas/benefit.py delete mode 100644 server/polar/user/service/benefit.py create mode 100644 server/tests/customer_portal/__init__.py rename server/tests/{user/endpoints/test_benefits.py => customer_portal/endpoints/test_benefit_grant.py} (68%) rename server/tests/{user => customer_portal}/endpoints/test_downloadables.py (82%) rename server/tests/{user => customer_portal}/endpoints/test_license_keys.py (87%) create mode 100644 server/tests/customer_portal/service/__init__.py rename server/tests/{user/service/test_benefit.py => customer_portal/service/test_benefit_grant.py} (57%) rename server/tests/{user => customer_portal}/service/test_order.py (50%) rename server/tests/{user => customer_portal}/service/test_subscription.py (76%) diff --git a/server/polar/api.py b/server/polar/api.py index 9fef52e570..38ad86f917 100644 --- a/server/polar/api.py +++ b/server/polar/api.py @@ -9,6 +9,7 @@ from polar.checkout.legacy.endpoints import router as checkout_legacy_router from polar.checkout_link.endpoints import router as checkout_link_router from polar.custom_field.endpoints import router as custom_field_router +from polar.customer_portal.endpoints import router as customer_portal_router from polar.dashboard.endpoints import router as dashboard_router from polar.discount.endpoints import router as discount_router from polar.embed.endpoints import router as embed_router @@ -126,3 +127,5 @@ router.include_router(embed_router) # /discounts router.include_router(discount_router) +# /customer-portal +router.include_router(customer_portal_router) diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index 752e203105..e59a9ce2e1 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -14,7 +14,6 @@ from polar.sentry import set_sentry_user from .models import ( - SUBJECTS, Anonymous, AuthMethod, AuthSubject, @@ -78,7 +77,7 @@ class _Authenticator: def __init__( self, *, - allowed_subjects: set[SubjectType] = SUBJECTS, + allowed_subjects: set[SubjectType], required_scopes: set[Scope] | None = None, ) -> None: self.allowed_subjects = allowed_subjects @@ -121,7 +120,7 @@ async def __call__( def Authenticator( - allowed_subjects: set[SubjectType] = SUBJECTS, + allowed_subjects: set[SubjectType], required_scopes: set[Scope] | None = None, ) -> _Authenticator: """ diff --git a/server/polar/auth/models.py b/server/polar/auth/models.py index eacd4a787d..e27cd6a52c 100644 --- a/server/polar/auth/models.py +++ b/server/polar/auth/models.py @@ -1,7 +1,7 @@ from enum import Enum, auto from typing import Generic, TypeGuard, TypeVar -from polar.models import Organization, User +from polar.models import Customer, Organization, User from .scope import Scope @@ -9,9 +9,8 @@ class Anonymous: ... -Subject = User | Organization | Anonymous -SubjectType = type[User] | type[Organization] | type[Anonymous] -SUBJECTS: set[SubjectType] = {User, Organization, Anonymous} +Subject = User | Organization | Customer | Anonymous +SubjectType = type[User] | type[Organization] | type[Customer] | type[Anonymous] class AuthMethod(Enum): @@ -66,10 +65,13 @@ def is_organization( return isinstance(auth_subject.subject, Organization) +def is_customer(auth_subject: AuthSubject[S]) -> TypeGuard[AuthSubject[Customer]]: + return isinstance(auth_subject.subject, Customer) + + __all__ = [ "Subject", "SubjectType", - "SUBJECTS", "AuthMethod", "AuthSubject", "is_anonymous", @@ -79,4 +81,5 @@ def is_organization( "Anonymous", "User", "Organization", + "Customer", ] diff --git a/server/polar/auth/scope.py b/server/polar/auth/scope.py index aa756639a3..92e90fa9ae 100644 --- a/server/polar/auth/scope.py +++ b/server/polar/auth/scope.py @@ -62,15 +62,12 @@ class Scope(StrEnum): issues_read = "issues:read" issues_write = "issues:write" - user_benefits_read = "user:benefits:read" - user_orders_read = "user:orders:read" - user_subscriptions_read = "user:subscriptions:read" - user_subscriptions_write = "user:subscriptions:write" - user_downloadables_read = "user:downloadables:read" - user_license_keys_read = "user:license_keys:read" user_advertisement_campaigns_read = "user:advertisement_campaigns:read" user_advertisement_campaigns_write = "user:advertisement_campaigns:write" + customer_portal_read = "customer_portal:read" + customer_portal_write = "customer_portal:write" + @classmethod def __get_pydantic_json_schema__( cls, core_schema: cs.CoreSchema, handler: GetJsonSchemaHandler @@ -117,12 +114,8 @@ def __get_pydantic_json_schema__( Scope.license_keys_read: "Read license keys", Scope.license_keys_write: "Modify license keys", Scope.webhooks_write: "Create or modify webhooks", - Scope.user_benefits_read: "Read your granted benefits", - Scope.user_orders_read: "Read your orders", - Scope.user_subscriptions_read: "Read your subscriptions", - Scope.user_subscriptions_write: "Create or modify your subscriptions", - Scope.user_downloadables_read: "Read your downloadable files", - Scope.user_license_keys_read: "Read license keys you have access to", + Scope.customer_portal_read: "Read your orders, subscriptions and benefits", + Scope.customer_portal_write: "Create or modify your orders, subscriptions and benefits", Scope.user_advertisement_campaigns_read: "Read your advertisement campaigns", Scope.user_advertisement_campaigns_write: ( "Create or modify your advertisement campaigns" diff --git a/server/polar/benefit/benefits/license_keys.py b/server/polar/benefit/benefits/license_keys.py index 245e434ca1..a4e59351ce 100644 --- a/server/polar/benefit/benefits/license_keys.py +++ b/server/polar/benefit/benefits/license_keys.py @@ -41,7 +41,7 @@ async def grant( if update and "license_key_id" in grant_properties: current_lk_id = UUID(grant_properties["license_key_id"]) - key = await license_key_service.user_grant( + key = await license_key_service.customer_grant( self.session, customer=customer, benefit=benefit, @@ -70,7 +70,7 @@ async def revoke( ) return grant_properties - await license_key_service.user_revoke( + await license_key_service.customer_revoke( self.session, customer=customer, benefit=benefit, diff --git a/server/polar/benefit/schemas.py b/server/polar/benefit/schemas.py index b03b13d82f..0e57844de5 100644 --- a/server/polar/benefit/schemas.py +++ b/server/polar/benefit/schemas.py @@ -1,4 +1,3 @@ -from collections.abc import Sequence from datetime import datetime from typing import Annotated, Any, Literal @@ -26,7 +25,6 @@ ) from polar.models.benefit import BenefitType from polar.models.benefit_grant import ( - BenefitGrantLicenseKeysProperties, BenefitGrantProperties, ) from polar.organization.schemas import Organization, OrganizationID @@ -254,7 +252,7 @@ class BenefitLicenseKeyExpirationProperties(Schema): class BenefitLicenseKeyActivationProperties(Schema): limit: int = Field(gt=0, le=50) - enable_user_admin: bool + enable_customer_admin: bool class BenefitLicenseKeysCreateProperties(Schema): @@ -549,11 +547,7 @@ class BenefitGrantWebhook(BenefitGrant): # BenefitSubscriber -class BenefitGrantSubscriber(BenefitGrantBase): ... - - class BenefitSubscriberBase(BenefitBase): - grants: Sequence[BenefitGrantSubscriber] organization: Organization @@ -569,14 +563,9 @@ class BenefitGrantAdsSubscriberProperties(Schema): ) -class BenefitGrantAds(BenefitGrantSubscriber): - properties: BenefitGrantAdsSubscriberProperties - - class BenefitAdsSubscriber(BenefitSubscriberBase): type: Literal[BenefitType.ads] properties: BenefitAdsProperties - grants: Sequence[BenefitGrantAds] class BenefitDiscordSubscriber(BenefitSubscriberBase): @@ -594,14 +583,9 @@ class BenefitDownloadablesSubscriber(BenefitSubscriberBase): properties: BenefitDownloadablesSubscriberProperties -class BenefitGrantLicenseKeys(BenefitGrantSubscriber): - properties: BenefitGrantLicenseKeysProperties - - class BenefitLicenseKeysSubscriber(BenefitSubscriberBase): type: Literal[BenefitType.license_keys] properties: BenefitLicenseKeysSubscriberProperties - grants: Sequence[BenefitGrantLicenseKeys] # Properties that are available to subscribers only diff --git a/server/polar/customer_portal/__init__.py b/server/polar/customer_portal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/polar/customer_portal/auth.py b/server/polar/customer_portal/auth.py new file mode 100644 index 0000000000..2fb276fa52 --- /dev/null +++ b/server/polar/customer_portal/auth.py @@ -0,0 +1,27 @@ +from typing import Annotated + +from fastapi import Depends + +from polar.auth.dependencies import Authenticator +from polar.auth.models import AuthSubject, Customer, User +from polar.auth.scope import Scope + +_CustomerPortalRead = Authenticator( + required_scopes={ + Scope.web_default, + Scope.customer_portal_read, + Scope.custom_fields_write, + }, + allowed_subjects={User, Customer}, +) +CustomerPortalRead = Annotated[ + AuthSubject[User | Customer], Depends(_CustomerPortalRead) +] + +_CustomerPortalWrite = Authenticator( + required_scopes={Scope.web_default, Scope.customer_portal_write}, + allowed_subjects={User, Customer}, +) +CustomerPortalWrite = Annotated[ + AuthSubject[User | Customer], Depends(_CustomerPortalWrite) +] diff --git a/server/polar/customer_portal/endpoints/__init__.py b/server/polar/customer_portal/endpoints/__init__.py new file mode 100644 index 0000000000..c689438adf --- /dev/null +++ b/server/polar/customer_portal/endpoints/__init__.py @@ -0,0 +1,15 @@ +from polar.routing import APIRouter + +from .benefit_grant import router as benefit_grant_router +from .downloadables import router as downloadables_router +from .license_keys import router as license_keys_router +from .order import router as order_router +from .subscription import router as subscription_router + +router = APIRouter(prefix="/customer-portal", tags=["customer_portal"]) + +router.include_router(benefit_grant_router) +router.include_router(order_router) +router.include_router(subscription_router) +router.include_router(downloadables_router) +router.include_router(license_keys_router) diff --git a/server/polar/user/endpoints/benefit.py b/server/polar/customer_portal/endpoints/benefit_grant.py similarity index 55% rename from server/polar/user/endpoints/benefit.py rename to server/polar/customer_portal/endpoints/benefit_grant.py index bbd3cdff2a..12d96faccb 100644 --- a/server/polar/user/endpoints/benefit.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -8,7 +8,7 @@ from polar.kit.pagination import ListResource, PaginationParamsQuery from polar.kit.schemas import MultipleQueryFilter from polar.kit.sorting import Sorting, SortingGetter -from polar.models import Benefit +from polar.models import BenefitGrant from polar.models.benefit import BenefitType from polar.openapi import APITag from polar.organization.schemas import OrganizationID @@ -16,29 +16,35 @@ from polar.routing import APIRouter from .. import auth -from ..schemas.benefit import UserBenefit, UserBenefitAdapter -from ..service.benefit import UserBenefitSortProperty -from ..service.benefit import user_benefit as user_benefit_service +from ..schemas.benefit_grant import BenefitGrant as BenefitGrantSchema +from ..schemas.benefit_grant import BenefitGrantAdapter +from ..service.benefit_grant import CustomerBenefitGrantSortProperty +from ..service.benefit_grant import ( + customer_benefit_grant as customer_benefit_grant_service, +) router = APIRouter( - prefix="/benefits", tags=["benefits", APITag.documented, APITag.featured] + prefix="/benefit-grants", + tags=["benefit-grants", APITag.documented], ) -BenefitID = Annotated[UUID4, Path(description="The benefit ID.")] -BenefitNotFound = { - "description": "Benefit not found or not granted.", +BenefitGrantID = Annotated[UUID4, Path(description="The benefit grant ID.")] +BenefitGrantNotFound = { + "description": "Benefit grant not found.", "model": ResourceNotFound.schema(), } ListSorting = Annotated[ - list[Sorting[UserBenefitSortProperty]], - Depends(SortingGetter(UserBenefitSortProperty, ["-granted_at"])), + list[Sorting[CustomerBenefitGrantSortProperty]], + Depends(SortingGetter(CustomerBenefitGrantSortProperty, ["-granted_at"])), ] -@router.get("/", summary="List Benefits", response_model=ListResource[UserBenefit]) +@router.get( + "/", summary="List Benefit Grants", response_model=ListResource[BenefitGrantSchema] +) async def list( - auth_subject: auth.UserBenefitsRead, + auth_subject: auth.CustomerPortalRead, pagination: PaginationParamsQuery, sorting: ListSorting, type: MultipleQueryFilter[BenefitType] | None = Query( @@ -54,9 +60,9 @@ async def list( None, title="SubscriptionID Filter", description="Filter by subscription ID." ), session: AsyncSession = Depends(get_db_session), -) -> ListResource[UserBenefit]: - """List my granted benefits.""" - results, count = await user_benefit_service.list( +) -> ListResource[BenefitGrantSchema]: + """List benefits grants of the authenticated customer or user.""" + results, count = await customer_benefit_grant_service.list( session, auth_subject, type=type, @@ -68,7 +74,7 @@ async def list( ) return ListResource.from_paginated_results( - [UserBenefitAdapter.validate_python(result) for result in results], + [BenefitGrantAdapter.validate_python(result) for result in results], count, pagination, ) @@ -76,19 +82,21 @@ async def list( @router.get( "/{id}", - summary="Get Benefit", - response_model=UserBenefit, - responses={404: BenefitNotFound}, + summary="Get Benefit Grant", + response_model=BenefitGrantSchema, + responses={404: BenefitGrantNotFound}, ) async def get( - id: BenefitID, - auth_subject: auth.UserBenefitsRead, + id: BenefitGrantID, + auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), -) -> Benefit: - """Get a granted benefit by ID.""" - benefit = await user_benefit_service.get_by_id(session, auth_subject, id) +) -> BenefitGrant: + """Get a benefit grant by ID for the authenticated customer or user.""" + benefit_grant = await customer_benefit_grant_service.get_by_id( + session, auth_subject, id + ) - if benefit is None: + if benefit_grant is None: raise ResourceNotFound() - return benefit + return benefit_grant diff --git a/server/polar/user/endpoints/downloadables.py b/server/polar/customer_portal/endpoints/downloadables.py similarity index 82% rename from server/polar/user/endpoints/downloadables.py rename to server/polar/customer_portal/endpoints/downloadables.py index d7bd1d84d1..032ea70f4d 100644 --- a/server/polar/user/endpoints/downloadables.py +++ b/server/polar/customer_portal/endpoints/downloadables.py @@ -13,9 +13,7 @@ from ..schemas.downloadables import DownloadableRead from ..service.downloadables import downloadable as downloadable_service -router = APIRouter( - prefix="/downloadables", tags=["downloadables", APITag.documented, APITag.featured] -) +router = APIRouter(prefix="/downloadables", tags=["downloadables", APITag.documented]) @router.get( @@ -24,23 +22,19 @@ response_model=ListResource[DownloadableRead], ) async def list( - auth_subject: auth.UserDownloadablesRead, + auth_subject: auth.CustomerPortalRead, pagination: PaginationParamsQuery, organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." ), benefit_id: MultipleQueryFilter[BenefitID] | None = Query( - None, - title="BenefitID Filter", - description=("Filter by given benefit ID. "), + None, title="BenefitID Filter", description="Filter by benefit ID." ), session: AsyncSession = Depends(get_db_session), ) -> ListResource[DownloadableRead]: - subject = auth_subject.subject - results, count = await downloadable_service.get_list( session, - user=subject, + auth_subject, pagination=pagination, organization_id=organization_id, benefit_id=benefit_id, @@ -62,16 +56,15 @@ async def list( 404: {"description": "Downloadable not found"}, 410: {"description": "Expired signature"}, }, + name="customer_portal.downloadables.get", ) async def get( token: str, - auth_subject: auth.UserDownloadablesRead, + auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - subject = auth_subject.subject - downloadable = await downloadable_service.get_from_token_or_raise( - session, user=subject, token=token + session, auth_subject, token=token ) signed = downloadable_service.generate_download_schema(downloadable) return RedirectResponse(signed.file.download.url, 302) diff --git a/server/polar/user/endpoints/license_keys.py b/server/polar/customer_portal/endpoints/license_keys.py similarity index 89% rename from server/polar/user/endpoints/license_keys.py rename to server/polar/customer_portal/endpoints/license_keys.py index c183c690d6..09243a6583 100644 --- a/server/polar/user/endpoints/license_keys.py +++ b/server/polar/customer_portal/endpoints/license_keys.py @@ -2,7 +2,7 @@ from pydantic import UUID4 from polar.benefit.schemas import BenefitID -from polar.exceptions import NotPermitted, ResourceNotFound, Unauthorized +from polar.exceptions import NotPermitted, ResourceNotFound from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import ListResource, PaginationParamsQuery from polar.kit.schemas import MultipleQueryFilter @@ -48,7 +48,7 @@ }, ) async def list( - auth_subject: auth.UserLicenseKeysRead, + auth_subject: auth.CustomerPortalRead, pagination: PaginationParamsQuery, organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." @@ -58,9 +58,9 @@ async def list( ), session: AsyncSession = Depends(get_db_session), ) -> ListResource[LicenseKeyRead]: - results, count = await license_key_service.get_user_list( + results, count = await license_key_service.get_customer_list( session, - user=auth_subject.subject, + auth_subject, organization_ids=organization_id, benefit_id=benefit_id, pagination=pagination, @@ -77,28 +77,21 @@ async def list( "/{id}", summary="Get License Key", response_model=LicenseKeyWithActivations, - responses={ - 401: UnauthorizedResponse, - 404: NotFoundResponse, - }, + responses={404: NotFoundResponse}, ) async def get( - auth_subject: auth.UserLicenseKeysRead, + auth_subject: auth.CustomerPortalRead, id: UUID4, session: AsyncSession = Depends(get_db_session), ) -> LicenseKeyWithActivations: """Get a license key.""" - lk = await license_key_service.get_loaded(session, id) + lk = await license_key_service.get_customer_license_key(session, auth_subject, id) if not lk: raise ResourceNotFound() - user_id = auth_subject.subject.id - if user_id != lk.user_id: - raise Unauthorized() - ret = LicenseKeyWithActivations.model_validate(lk) activations = lk.benefit.properties.get("activations") - if not (activations and activations.get("enable_user_admin")): + if not (activations and activations.get("enable_customer_admin")): ret.activations = [] return ret diff --git a/server/polar/user/endpoints/order.py b/server/polar/customer_portal/endpoints/order.py similarity index 71% rename from server/polar/user/endpoints/order.py rename to server/polar/customer_portal/endpoints/order.py index 4dca4289a8..0ec6e67c3f 100644 --- a/server/polar/user/endpoints/order.py +++ b/server/polar/customer_portal/endpoints/order.py @@ -17,26 +17,24 @@ from polar.routing import APIRouter from .. import auth -from ..schemas.order import UserOrder, UserOrderInvoice -from ..service.order import UserOrderSortProperty -from ..service.order import user_order as user_order_service +from ..schemas.order import CustomerOrder, CustomerOrderInvoice +from ..service.order import CustomerOrderSortProperty +from ..service.order import customer_order as customer_order_service -router = APIRouter( - prefix="/orders", tags=["orders", APITag.documented, APITag.featured] -) +router = APIRouter(prefix="/orders", tags=["orders", APITag.documented]) OrderID = Annotated[UUID4, Path(description="The order ID.")] OrderNotFound = {"description": "Order not found.", "model": ResourceNotFound.schema()} ListSorting = Annotated[ - list[Sorting[UserOrderSortProperty]], - Depends(SortingGetter(UserOrderSortProperty, ["-created_at"])), + list[Sorting[CustomerOrderSortProperty]], + Depends(SortingGetter(CustomerOrderSortProperty, ["-created_at"])), ] -@router.get("/", summary="List Orders", response_model=ListResource[UserOrder]) +@router.get("/", summary="List Orders", response_model=ListResource[CustomerOrder]) async def list( - auth_subject: auth.UserOrdersRead, + auth_subject: auth.CustomerPortalRead, pagination: PaginationParamsQuery, sorting: ListSorting, organization_id: MultipleQueryFilter[OrganizationID] | None = Query( @@ -62,9 +60,9 @@ async def list( None, description="Search by product or organization name." ), session: AsyncSession = Depends(get_db_session), -) -> ListResource[UserOrder]: - """List my orders.""" - results, count = await user_order_service.list( +) -> ListResource[CustomerOrder]: + """List orders of the authenticated customer or user.""" + results, count = await customer_order_service.list( session, auth_subject, organization_id=organization_id, @@ -77,7 +75,7 @@ async def list( ) return ListResource.from_paginated_results( - [UserOrder.model_validate(result) for result in results], + [CustomerOrder.model_validate(result) for result in results], count, pagination, ) @@ -86,16 +84,16 @@ async def list( @router.get( "/{id}", summary="Get Order", - response_model=UserOrder, + response_model=CustomerOrder, responses={404: OrderNotFound}, ) async def get( id: OrderID, - auth_subject: auth.UserOrdersRead, + auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), ) -> Order: - """Get an order by ID.""" - order = await user_order_service.get_by_id(session, auth_subject, id) + """Get an order by ID for the authenticated customer or user.""" + order = await customer_order_service.get_by_id(session, auth_subject, id) if order is None: raise ResourceNotFound() @@ -106,20 +104,20 @@ async def get( @router.get( "/{id}/invoice", summary="Get Order Invoice", - response_model=UserOrderInvoice, + response_model=CustomerOrderInvoice, responses={404: OrderNotFound}, ) async def invoice( id: OrderID, - auth_subject: auth.UserOrdersRead, + auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), -) -> UserOrderInvoice: +) -> CustomerOrderInvoice: """Get an order's invoice data.""" - order = await user_order_service.get_by_id(session, auth_subject, id) + order = await customer_order_service.get_by_id(session, auth_subject, id) if order is None: raise ResourceNotFound() - invoice_url = await user_order_service.get_order_invoice_url(order) + invoice_url = await customer_order_service.get_order_invoice_url(order) - return UserOrderInvoice(url=invoice_url) + return CustomerOrderInvoice(url=invoice_url) diff --git a/server/polar/user/endpoints/subscription.py b/server/polar/customer_portal/endpoints/subscription.py similarity index 78% rename from server/polar/user/endpoints/subscription.py rename to server/polar/customer_portal/endpoints/subscription.py index 967834a341..cc392b4c00 100644 --- a/server/polar/user/endpoints/subscription.py +++ b/server/polar/customer_portal/endpoints/subscription.py @@ -17,14 +17,14 @@ from .. import auth from ..schemas.subscription import ( - UserSubscription, - UserSubscriptionUpdate, + CustomerSubscription, + CustomerSubscriptionUpdate, ) from ..service.subscription import ( AlreadyCanceledSubscription, - UserSubscriptionSortProperty, + CustomerSubscriptionSortProperty, ) -from ..service.subscription import user_subscription as user_subscription_service +from ..service.subscription import customer_subscription as user_subscription_service router = APIRouter( prefix="/subscriptions", tags=["subscriptions", APITag.documented, APITag.featured] @@ -37,16 +37,16 @@ } ListSorting = Annotated[ - list[Sorting[UserSubscriptionSortProperty]], - Depends(SortingGetter(UserSubscriptionSortProperty, ["-started_at"])), + list[Sorting[CustomerSubscriptionSortProperty]], + Depends(SortingGetter(CustomerSubscriptionSortProperty, ["-started_at"])), ] @router.get( - "/", summary="List Subscriptions", response_model=ListResource[UserSubscription] + "/", summary="List Subscriptions", response_model=ListResource[CustomerSubscription] ) async def list( - auth_subject: auth.UserSubscriptionsRead, + auth_subject: auth.CustomerPortalRead, pagination: PaginationParamsQuery, sorting: ListSorting, organization_id: MultipleQueryFilter[OrganizationID] | None = Query( @@ -63,8 +63,8 @@ async def list( None, description="Search by product or organization name." ), session: AsyncSession = Depends(get_db_session), -) -> ListResource[UserSubscription]: - """List my subscriptions.""" +) -> ListResource[CustomerSubscription]: + """List subscriptions of the authenticated customer or user.""" results, count = await user_subscription_service.list( session, auth_subject, @@ -77,7 +77,7 @@ async def list( ) return ListResource.from_paginated_results( - [UserSubscription.model_validate(result) for result in results], + [CustomerSubscription.model_validate(result) for result in results], count, pagination, ) @@ -86,15 +86,15 @@ async def list( @router.get( "/{id}", summary="Get Subscription", - response_model=UserSubscription, + response_model=CustomerSubscription, responses={404: SubscriptionNotFound}, ) async def get( id: SubscriptionID, - auth_subject: auth.UserSubscriptionsRead, + auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), ) -> Subscription: - """Get a subscription by ID.""" + """Get a subscription for the authenticated customer or user.""" subscription = await user_subscription_service.get_by_id(session, auth_subject, id) if subscription is None: @@ -106,7 +106,7 @@ async def get( @router.patch( "/{id}", summary="Update Subscription", - response_model=UserSubscription, + response_model=CustomerSubscription, responses={ 200: {"description": "Subscription updated."}, 404: SubscriptionNotFound, @@ -114,11 +114,11 @@ async def get( ) async def update( id: SubscriptionID, - subscription_update: UserSubscriptionUpdate, - auth_subject: auth.UserSubscriptionsWrite, + subscription_update: CustomerSubscriptionUpdate, + auth_subject: auth.CustomerPortalWrite, session: AsyncSession = Depends(get_db_session), ) -> Subscription: - """Update a subscription.""" + """Update a subscription of the authenticated customer or user.""" subscription = await user_subscription_service.get_by_id(session, auth_subject, id) if subscription is None: @@ -132,7 +132,7 @@ async def update( @router.delete( "/{id}", summary="Cancel Subscription", - response_model=UserSubscription, + response_model=CustomerSubscription, responses={ 200: {"description": "Subscription canceled."}, 403: { @@ -147,10 +147,10 @@ async def update( ) async def cancel( id: SubscriptionID, - auth_subject: auth.UserSubscriptionsWrite, + auth_subject: auth.CustomerPortalWrite, session: AsyncSession = Depends(get_db_session), ) -> Subscription: - """Cancel a subscription.""" + """Cancel a subscription of the authenticated customer or user.""" subscription = await user_subscription_service.get_by_id(session, auth_subject, id) if subscription is None: diff --git a/server/polar/customer_portal/schemas/__init__.py b/server/polar/customer_portal/schemas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/polar/customer_portal/schemas/benefit_grant.py b/server/polar/customer_portal/schemas/benefit_grant.py new file mode 100644 index 0000000000..4418c8f3b5 --- /dev/null +++ b/server/polar/customer_portal/schemas/benefit_grant.py @@ -0,0 +1,82 @@ +from datetime import datetime +from typing import Annotated + +from pydantic import UUID4, TypeAdapter + +from polar.benefit.schemas import ( + BenefitAdsSubscriber, + BenefitCustomSubscriber, + BenefitDiscordSubscriber, + BenefitDownloadablesSubscriber, + BenefitGitHubRepositorySubscriber, + BenefitLicenseKeysSubscriber, + BenefitSubscriber, +) +from polar.kit.schemas import IDSchema, MergeJSONSchema, TimestampedSchema +from polar.models.benefit_grant import ( + BenefitGrantAdsProperties, + BenefitGrantCustomProperties, + BenefitGrantDiscordProperties, + BenefitGrantDownloadablesProperties, + BenefitGrantGitHubRepositoryProperties, + BenefitGrantLicenseKeysProperties, +) + + +class BenefitGrantBase(IDSchema, TimestampedSchema): + granted_at: datetime | None + revoked_at: datetime | None + customer_id: UUID4 + benefit_id: UUID4 + subscription_id: UUID4 | None + order_id: UUID4 | None + is_granted: bool + is_revoked: bool + + +BenefitCustomer = Annotated[ + BenefitSubscriber, + MergeJSONSchema({"title": "BenefitCustomer"}), +] + + +class BenefitGrantDiscord(BenefitGrantBase): + benefit: BenefitDiscordSubscriber + properties: BenefitGrantDiscordProperties + + +class BenefitGrantGitHubRepository(BenefitGrantBase): + benefit: BenefitGitHubRepositorySubscriber + properties: BenefitGrantGitHubRepositoryProperties + + +class BenefitGrantDownloadables(BenefitGrantBase): + benefit: BenefitDownloadablesSubscriber + properties: BenefitGrantDownloadablesProperties + + +class BenefitGrantLicenseKeys(BenefitGrantBase): + benefit: BenefitLicenseKeysSubscriber + properties: BenefitGrantLicenseKeysProperties + + +class BenefitGrantAds(BenefitGrantBase): + benefit: BenefitAdsSubscriber + properties: BenefitGrantAdsProperties + + +class BenefitGrantCustomer(BenefitGrantBase): + benefit: BenefitCustomSubscriber + properties: BenefitGrantCustomProperties + + +BenefitGrant = Annotated[ + BenefitGrantDiscord + | BenefitGrantGitHubRepository + | BenefitGrantDownloadables + | BenefitGrantLicenseKeys + | BenefitGrantAds + | BenefitGrantCustomer, + MergeJSONSchema({"title": "BenefitGrant"}), +] +BenefitGrantAdapter: TypeAdapter[BenefitGrant] = TypeAdapter(BenefitGrant) diff --git a/server/polar/user/schemas/downloadables.py b/server/polar/customer_portal/schemas/downloadables.py similarity index 100% rename from server/polar/user/schemas/downloadables.py rename to server/polar/customer_portal/schemas/downloadables.py diff --git a/server/polar/user/schemas/order.py b/server/polar/customer_portal/schemas/order.py similarity index 63% rename from server/polar/user/schemas/order.py rename to server/polar/customer_portal/schemas/order.py index 23f05541ae..6ce740d208 100644 --- a/server/polar/user/schemas/order.py +++ b/server/polar/customer_portal/schemas/order.py @@ -12,35 +12,38 @@ from polar.subscription.schemas import SubscriptionBase -class UserOrderBase(TimestampedSchema): +class CustomerOrderBase(TimestampedSchema): id: UUID4 amount: int tax_amount: int currency: str - user_id: UUID4 + customer_id: UUID4 + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) product_id: UUID4 product_price_id: UUID4 subscription_id: UUID4 | None -class UserOrderProduct(ProductBase): +class CustomerOrderProduct(ProductBase): prices: ProductPriceList benefits: BenefitPublicList medias: ProductMediaList organization: Organization -class UserOrderSubscription(SubscriptionBase): ... +class CustomerOrderSubscription(SubscriptionBase): ... -class UserOrder(UserOrderBase): - product: UserOrderProduct +class CustomerOrder(CustomerOrderBase): + product: CustomerOrderProduct product_price: ProductPrice - subscription: UserOrderSubscription | None + subscription: CustomerOrderSubscription | None -class UserOrderInvoice(Schema): +class CustomerOrderInvoice(Schema): """Order's invoice data.""" url: str = Field(..., description="The URL to the invoice.") diff --git a/server/polar/user/schemas/subscription.py b/server/polar/customer_portal/schemas/subscription.py similarity index 55% rename from server/polar/user/schemas/subscription.py rename to server/polar/customer_portal/schemas/subscription.py index 562926dacc..cc9f64a71a 100644 --- a/server/polar/user/schemas/subscription.py +++ b/server/polar/customer_portal/schemas/subscription.py @@ -2,7 +2,7 @@ from pydantic import UUID4, Field -from polar.kit.schemas import EmailStrDNS, Schema +from polar.kit.schemas import Schema from polar.models.subscription import SubscriptionStatus from polar.organization.schemas import Organization from polar.product.schemas import ( @@ -15,7 +15,7 @@ from polar.subscription.schemas import SubscriptionBase -class UserSubscriptionBase(SubscriptionBase): +class CustomerSubscriptionBase(SubscriptionBase): status: SubscriptionStatus current_period_start: datetime current_period_end: datetime | None @@ -23,36 +23,25 @@ class UserSubscriptionBase(SubscriptionBase): started_at: datetime | None ended_at: datetime | None - user_id: UUID4 + customer_id: UUID4 + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) product_id: UUID4 price_id: UUID4 -class UserSubscriptionProduct(ProductBase): +class CustomerSubscriptionProduct(ProductBase): prices: ProductPriceList benefits: BenefitPublicList medias: ProductMediaList organization: Organization -class UserSubscription(UserSubscriptionBase): - product: UserSubscriptionProduct +class CustomerSubscription(CustomerSubscriptionBase): + product: CustomerSubscriptionProduct price: ProductPrice -class UserFreeSubscriptionCreate(Schema): - product_id: UUID4 = Field( - ..., - description="ID of the free tier to subscribe to.", - ) - customer_email: EmailStrDNS | None = Field( - None, - description=( - "Email of the customer. " - "This field is required if the API is called outside the Polar app." - ), - ) - - -class UserSubscriptionUpdate(Schema): +class CustomerSubscriptionUpdate(Schema): product_price_id: UUID4 diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py new file mode 100644 index 0000000000..f40afec1a4 --- /dev/null +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -0,0 +1,118 @@ +import uuid +from collections.abc import Sequence +from enum import StrEnum +from typing import Any + +from sqlalchemy import Select, UnaryExpression, asc, desc, select +from sqlalchemy.orm import contains_eager + +from polar.auth.models import AuthSubject, is_customer, is_user +from polar.kit.db.postgres import AsyncSession +from polar.kit.pagination import PaginationParams, paginate +from polar.kit.services import ResourceServiceReader +from polar.kit.sorting import Sorting +from polar.models import ( + Benefit, + BenefitGrant, + Customer, + Organization, + User, +) +from polar.models.benefit import BenefitType + + +class CustomerBenefitGrantSortProperty(StrEnum): + granted_at = "granted_at" + type = "type" + organization = "organization" + + +class CustomerBenefitGrantService(ResourceServiceReader[BenefitGrant]): + async def list( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Customer], + *, + type: Sequence[BenefitType] | None = None, + benefit_id: Sequence[uuid.UUID] | None = None, + organization_id: Sequence[uuid.UUID] | None = None, + order_id: Sequence[uuid.UUID] | None = None, + subscription_id: Sequence[uuid.UUID] | None = None, + pagination: PaginationParams, + sorting: list[Sorting[CustomerBenefitGrantSortProperty]] = [ + (CustomerBenefitGrantSortProperty.granted_at, True) + ], + ) -> tuple[Sequence[BenefitGrant], int]: + statement = self._get_readable_benefit_grant_statement(auth_subject) + + if type is not None: + statement = statement.where(Benefit.type.in_(type)) + + if benefit_id is not None: + statement = statement.where(BenefitGrant.benefit_id.in_(benefit_id)) + + if organization_id is not None: + statement = statement.where(Benefit.organization_id.in_(organization_id)) + + if order_id is not None: + statement = statement.where(BenefitGrant.order_id.in_(order_id)) + + if subscription_id is not None: + statement = statement.where( + BenefitGrant.subscription_id.in_(subscription_id) + ) + + order_by_clauses: list[UnaryExpression[Any]] = [] + for criterion, is_desc in sorting: + clause_function = desc if is_desc else asc + if criterion == CustomerBenefitGrantSortProperty.granted_at: + order_by_clauses.append(clause_function(BenefitGrant.granted_at)) + elif criterion == CustomerBenefitGrantSortProperty.type: + order_by_clauses.append(clause_function(Benefit.type)) + elif criterion == CustomerBenefitGrantSortProperty.organization: + order_by_clauses.append(clause_function(Organization.slug)) + statement = statement.order_by(*order_by_clauses) + + return await paginate(session, statement, pagination=pagination) + + async def get_by_id( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Customer], + id: uuid.UUID, + ) -> BenefitGrant | None: + statement = self._get_readable_benefit_grant_statement(auth_subject).where( + BenefitGrant.id == id + ) + + result = await session.execute(statement) + return result.scalar_one_or_none() + + def _get_readable_benefit_grant_statement( + self, auth_subject: AuthSubject[User | Customer] + ) -> Select[tuple[BenefitGrant]]: + statement = ( + select(BenefitGrant) + .join(Benefit, onclause=Benefit.id == BenefitGrant.benefit_id) + .join(Organization, onclause=Benefit.organization_id == Organization.id) + .where( + BenefitGrant.deleted_at.is_(None), + ) + .options( + contains_eager(BenefitGrant.benefit).options( + contains_eager(Benefit.organization) + ), + ) + ) + + if is_user(auth_subject): + raise NotImplementedError("TODO") + elif is_customer(auth_subject): + statement = statement.where( + BenefitGrant.customer_id == auth_subject.subject.id + ) + + return statement + + +customer_benefit_grant = CustomerBenefitGrantService(BenefitGrant) diff --git a/server/polar/user/service/downloadables.py b/server/polar/customer_portal/service/downloadables.py similarity index 90% rename from server/polar/user/service/downloadables.py rename to server/polar/customer_portal/service/downloadables.py index f5cc2110fd..ba26d4bad3 100644 --- a/server/polar/user/service/downloadables.py +++ b/server/polar/customer_portal/service/downloadables.py @@ -6,6 +6,7 @@ from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer from sqlalchemy.orm import contains_eager +from polar.auth.models import AuthSubject, is_customer, is_user from polar.config import settings from polar.exceptions import ( BadRequest, @@ -42,13 +43,13 @@ class DownloadableService( async def get_list( self, session: AsyncSession, + auth_subject: AuthSubject[User | Customer], *, - user: User, pagination: PaginationParams, organization_id: Sequence[UUID] | None = None, benefit_id: Sequence[UUID] | None = None, ) -> tuple[Sequence[Downloadable], int]: - statement = self._get_base_query(user) + statement = self._get_base_query(auth_subject) if organization_id: statement = statement.where(File.organization_id.in_(organization_id)) @@ -185,11 +186,14 @@ def create_download_token(self, downloadable: Downloadable) -> DownloadableURL: last_downloaded_at=last_downloaded_at, ) ) - redirect_to = f"{settings.BASE_URL}/users/downloadables/{token}" + redirect_to = f"{settings.BASE_URL}/customer-portal/downloadables/{token}" return DownloadableURL(url=redirect_to, expires_at=expires_at) async def get_from_token_or_raise( - self, session: AsyncSession, user: User, token: str + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Customer], + token: str, ) -> Downloadable: try: unpacked = token_serializer.loads( @@ -203,7 +207,7 @@ async def get_from_token_or_raise( except KeyError: raise BadRequest() - statement = self._get_base_query(user).where(Downloadable.id == id) + statement = self._get_base_query(auth_subject).where(Downloadable.id == id) res = await session.execute(statement) downloadable = res.scalars().one_or_none() if not downloadable: @@ -220,14 +224,15 @@ def generate_download_schema(self, downloadable: Downloadable) -> DownloadableRe file=file_schema, ) - def _get_base_query(self, user: User) -> sql.Select[tuple[Downloadable]]: + def _get_base_query( + self, auth_subject: AuthSubject[User | Customer] + ) -> sql.Select[tuple[Downloadable]]: statement = ( sql.select(Downloadable) .join(File) .join(Benefit) .options(contains_eager(Downloadable.file)) .where( - Downloadable.user_id == user.id, Downloadable.status == DownloadableStatus.granted, Downloadable.deleted_at.is_(None), File.deleted_at.is_(None), @@ -237,6 +242,14 @@ def _get_base_query(self, user: User) -> sql.Select[tuple[Downloadable]]: ) .order_by(Downloadable.created_at.desc()) ) + + if is_user(auth_subject): + raise NotImplementedError("TODO") + elif is_customer(auth_subject): + statement = statement.where( + Downloadable.customer_id == auth_subject.subject.id + ) + return statement diff --git a/server/polar/user/service/order.py b/server/polar/customer_portal/service/order.py similarity index 78% rename from server/polar/user/service/order.py rename to server/polar/customer_portal/service/order.py index 9167ef72c6..c390ae87ab 100644 --- a/server/polar/user/service/order.py +++ b/server/polar/customer_portal/service/order.py @@ -6,28 +6,28 @@ from sqlalchemy import Select, UnaryExpression, asc, desc, or_, select from sqlalchemy.orm import aliased, contains_eager, joinedload, selectinload -from polar.auth.models import AuthSubject +from polar.auth.models import AuthSubject, is_customer, is_user from polar.exceptions import PolarError from polar.integrations.stripe.service import stripe as stripe_service from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.kit.sorting import Sorting -from polar.models import Order, Organization, Product, ProductPrice, User +from polar.models import Customer, Order, Organization, Product, ProductPrice, User from polar.models.product_price import ProductPriceType -class UserOrderError(PolarError): ... +class CustomerOrderError(PolarError): ... -class InvoiceNotAvailable(UserOrderError): +class InvoiceNotAvailable(CustomerOrderError): def __init__(self, order: Order) -> None: self.order = order message = "The invoice is not available for this order." super().__init__(message, 404) -class UserOrderSortProperty(StrEnum): +class CustomerOrderSortProperty(StrEnum): created_at = "created_at" amount = "amount" organization = "organization" @@ -35,11 +35,11 @@ class UserOrderSortProperty(StrEnum): subscription = "subscription" -class UserOrderService(ResourceServiceReader[Order]): +class CustomerOrderService(ResourceServiceReader[Order]): async def list( self, session: AsyncSession, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[User | Customer], *, organization_id: Sequence[uuid.UUID] | None = None, product_id: Sequence[uuid.UUID] | None = None, @@ -47,8 +47,8 @@ async def list( subscription_id: Sequence[uuid.UUID] | None = None, query: str | None = None, pagination: PaginationParams, - sorting: list[Sorting[UserOrderSortProperty]] = [ - (UserOrderSortProperty.created_at, True) + sorting: list[Sorting[CustomerOrderSortProperty]] = [ + (CustomerOrderSortProperty.created_at, True) ], ) -> tuple[Sequence[Order], int]: statement = self._get_readable_order_statement(auth_subject) @@ -91,15 +91,15 @@ async def list( order_by_clauses: list[UnaryExpression[Any]] = [] for criterion, is_desc in sorting: clause_function = desc if is_desc else asc - if criterion == UserOrderSortProperty.created_at: + if criterion == CustomerOrderSortProperty.created_at: order_by_clauses.append(clause_function(Order.created_at)) - elif criterion == UserOrderSortProperty.amount: + elif criterion == CustomerOrderSortProperty.amount: order_by_clauses.append(clause_function(Order.amount)) - elif criterion == UserOrderSortProperty.organization: + elif criterion == CustomerOrderSortProperty.organization: order_by_clauses.append(clause_function(Organization.slug)) - elif criterion == UserOrderSortProperty.product: + elif criterion == CustomerOrderSortProperty.product: order_by_clauses.append(clause_function(Product.name)) - elif criterion == UserOrderSortProperty.subscription: + elif criterion == CustomerOrderSortProperty.subscription: order_by_clauses.append(clause_function(Order.subscription_id)) statement = statement.order_by(*order_by_clauses) @@ -108,7 +108,7 @@ async def list( async def get_by_id( self, session: AsyncSession, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[User | Customer], id: uuid.UUID, ) -> Order | None: statement = ( @@ -139,15 +139,22 @@ async def get_order_invoice_url(self, order: Order) -> str: return stripe_invoice.hosted_invoice_url def _get_readable_order_statement( - self, auth_subject: AuthSubject[User] + self, auth_subject: AuthSubject[User | Customer] ) -> Select[tuple[Order]]: statement = ( select(Order) - .where(Order.deleted_at.is_(None), Order.user_id == auth_subject.subject.id) + .where(Order.deleted_at.is_(None)) .join(Order.product) .options(contains_eager(Order.product)) ) + + if is_user(auth_subject): + raise NotImplementedError("TODO") + elif is_customer(auth_subject): + customer = auth_subject.subject + statement = statement.where(Order.customer_id == customer.id) + return statement -user_order = UserOrderService(Order) +customer_order = CustomerOrderService(Order) diff --git a/server/polar/user/service/subscription.py b/server/polar/customer_portal/service/subscription.py similarity index 87% rename from server/polar/user/service/subscription.py rename to server/polar/customer_portal/service/subscription.py index 52442aa691..43aca92c73 100644 --- a/server/polar/user/service/subscription.py +++ b/server/polar/customer_portal/service/subscription.py @@ -6,7 +6,7 @@ from sqlalchemy import Select, UnaryExpression, asc, desc, nulls_first, or_, select from sqlalchemy.orm import aliased, contains_eager, joinedload, selectinload -from polar.auth.models import AuthSubject +from polar.auth.models import AuthSubject, is_customer, is_user from polar.exceptions import PolarError, PolarRequestValidationError from polar.integrations.stripe.service import stripe as stripe_service from polar.kit.db.postgres import AsyncSession @@ -15,6 +15,7 @@ from polar.kit.sorting import Sorting from polar.kit.utils import utc_now from polar.models import ( + Customer, Organization, Product, ProductPrice, @@ -30,13 +31,13 @@ from polar.product.service.product_price import product_price as product_price_service from polar.subscription.service import subscription as subscription_service -from ..schemas.subscription import UserSubscriptionUpdate +from ..schemas.subscription import CustomerSubscriptionUpdate -class UserSubscriptionError(PolarError): ... +class CustomerSubscriptionError(PolarError): ... -class AlreadyCanceledSubscription(UserSubscriptionError): +class AlreadyCanceledSubscription(CustomerSubscriptionError): def __init__(self, subscription: Subscription) -> None: self.subscription = subscription message = ( @@ -45,14 +46,14 @@ def __init__(self, subscription: Subscription) -> None: super().__init__(message, 403) -class SubscriptionNotActiveOnStripe(UserSubscriptionError): +class SubscriptionNotActiveOnStripe(CustomerSubscriptionError): def __init__(self, subscription: Subscription) -> None: self.subscription = subscription message = "This subscription is not active on Stripe." super().__init__(message, 400) -class UserSubscriptionSortProperty(StrEnum): +class CustomerSubscriptionSortProperty(StrEnum): started_at = "started_at" amount = "amount" status = "status" @@ -60,19 +61,19 @@ class UserSubscriptionSortProperty(StrEnum): product = "product" -class UserSubscriptionService(ResourceServiceReader[Subscription]): +class CustomerSubscriptionService(ResourceServiceReader[Subscription]): async def list( self, session: AsyncSession, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[User | Customer], *, organization_id: Sequence[uuid.UUID] | None = None, product_id: Sequence[uuid.UUID] | None = None, active: bool | None = None, query: str | None = None, pagination: PaginationParams, - sorting: list[Sorting[UserSubscriptionSortProperty]] = [ - (UserSubscriptionSortProperty.started_at, True) + sorting: list[Sorting[CustomerSubscriptionSortProperty]] = [ + (CustomerSubscriptionSortProperty.started_at, True) ], ) -> tuple[Sequence[Subscription], int]: statement = self._get_readable_subscription_statement(auth_subject).where( @@ -120,17 +121,17 @@ async def list( order_by_clauses: list[UnaryExpression[Any]] = [] for criterion, is_desc in sorting: clause_function = desc if is_desc else asc - if criterion == UserSubscriptionSortProperty.started_at: + if criterion == CustomerSubscriptionSortProperty.started_at: order_by_clauses.append(clause_function(Subscription.started_at)) - elif criterion == UserSubscriptionSortProperty.amount: + elif criterion == CustomerSubscriptionSortProperty.amount: order_by_clauses.append( nulls_first(clause_function(Subscription.amount)) ) - elif criterion == UserSubscriptionSortProperty.status: + elif criterion == CustomerSubscriptionSortProperty.status: order_by_clauses.append(clause_function(Subscription.status)) - elif criterion == UserSubscriptionSortProperty.organization: + elif criterion == CustomerSubscriptionSortProperty.organization: order_by_clauses.append(clause_function(Organization.slug)) - elif criterion == UserSubscriptionSortProperty.product: + elif criterion == CustomerSubscriptionSortProperty.product: order_by_clauses.append(clause_function(Product.name)) statement = statement.order_by(*order_by_clauses) @@ -139,7 +140,7 @@ async def list( async def get_by_id( self, session: AsyncSession, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[User | Customer], id: uuid.UUID, ) -> Subscription | None: statement = ( @@ -162,7 +163,7 @@ async def update( session: AsyncSession, *, subscription: Subscription, - subscription_update: UserSubscriptionUpdate, + subscription_update: CustomerSubscriptionUpdate, ) -> Subscription: price = await product_price_service.get_by_id( session, subscription_update.product_price_id @@ -299,13 +300,18 @@ async def cancel( return subscription def _get_readable_subscription_statement( - self, auth_subject: AuthSubject[User] + self, auth_subject: AuthSubject[User | Customer] ) -> Select[tuple[Subscription]]: - statement = select(Subscription).where( - Subscription.deleted_at.is_(None), - Subscription.user_id == auth_subject.subject.id, - ) + statement = select(Subscription).where(Subscription.deleted_at.is_(None)) + + if is_user(auth_subject): + raise NotImplementedError("TODO") + elif is_customer(auth_subject): + statement = statement.where( + Subscription.customer_id == auth_subject.subject.id + ) + return statement -user_subscription = UserSubscriptionService(Subscription) +customer_subscription = CustomerSubscriptionService(Subscription) diff --git a/server/polar/license_key/schemas.py b/server/polar/license_key/schemas.py index db70e4fcee..4297263b85 100644 --- a/server/polar/license_key/schemas.py +++ b/server/polar/license_key/schemas.py @@ -6,7 +6,10 @@ from polar.benefit.schemas import BenefitID from polar.exceptions import ResourceNotFound, Unauthorized -from polar.kit.schemas import Schema +from polar.kit.address import Address +from polar.kit.metadata import MetadataOutputMixin +from polar.kit.schemas import IDSchema, Schema, TimestampedSchema +from polar.kit.tax import TaxID from polar.kit.utils import generate_uuid, utc_now from polar.models.benefit import ( BenefitLicenseKeyActivationProperties, @@ -39,7 +42,7 @@ class LicenseKeyValidate(Schema): organization_id: UUID4 activation_id: UUID4 | None = None benefit_id: BenefitID | None = None - user_id: UUID4 | None = None + customer_id: UUID4 | None = None increment_usage: int | None = None conditions: dict[str, Any] = {} @@ -58,18 +61,26 @@ class LicenseKeyDeactivate(Schema): activation_id: UUID4 -class LicenseKeyUser(Schema): - id: UUID4 - public_name: str +class LicenseKeyCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): email: str - avatar_url: str | None + email_verified: bool + name: str | None + billing_address: Address | None + tax_id: TaxID | None + organization_id: UUID4 class LicenseKeyRead(Schema): id: UUID4 organization_id: UUID4 - user_id: UUID4 - user: LicenseKeyUser + user_id: UUID4 = Field( + validation_alias="customer_id", deprecated="Use `customer_id`." + ) + customer_id: UUID4 + user: LicenseKeyCustomer = Field( + validation_alias="customer", deprecated="Use `customer`." + ) + customer: LicenseKeyCustomer benefit_id: BenefitID key: str display_key: str diff --git a/server/polar/license_key/service.py b/server/polar/license_key/service.py index 5983939199..39eea3e81f 100644 --- a/server/polar/license_key/service.py +++ b/server/polar/license_key/service.py @@ -5,7 +5,7 @@ from sqlalchemy import Select, and_, func, select from sqlalchemy.orm import contains_eager, joinedload -from polar.auth.models import AuthSubject, is_organization, is_user +from polar.auth.models import AuthSubject, is_customer, is_organization, is_user from polar.exceptions import BadRequest, NotPermitted, ResourceNotFound from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceService @@ -166,28 +166,6 @@ async def get_list( return await paginate(session, query, pagination=pagination) - async def get_user_list( - self, - session: AsyncSession, - *, - user: User, - pagination: PaginationParams, - benefit_id: UUID | None = None, - organization_ids: Sequence[UUID] | None = None, - ) -> tuple[Sequence[LicenseKey], int]: - query = ( - self._get_select_base() - .where(LicenseKey.user_id == user.id) - .order_by(LicenseKey.created_at.asc()) - ) - if organization_ids: - query = query.where(LicenseKey.organization_id.in_(organization_ids)) - - if benefit_id: - query = query.where(LicenseKey.benefit_id == benefit_id) - - return await paginate(session, query, pagination=pagination) - async def update( self, session: AsyncSession, @@ -210,25 +188,19 @@ async def validate( license_key: LicenseKey, validate: LicenseKeyValidate, ) -> tuple[LicenseKey, LicenseKeyActivation | None]: + bound_logger = log.bind( + license_key_id=license_key.id, + organization_id=license_key.organization_id, + customer_id=license_key.customer_id, + benefit_id=license_key.benefit_id, + ) if not license_key.is_active(): - log.info( - "license_key.validate.invalid_status", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - ) + bound_logger.info("license_key.validate.invalid_status") raise ResourceNotFound("License key is no longer active.") if license_key.expires_at: if utc_now() >= license_key.expires_at: - log.info( - "license_key.validate.invalid_ttl", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - ) + bound_logger.info("license_key.validate.invalid_ttl") raise ResourceNotFound("License key has expired.") activation = None @@ -240,46 +212,25 @@ async def validate( ) if activation.conditions and validate.conditions != activation.conditions: # Skip logging UGC conditions - log.info( - "license_key.validate.invalid_conditions", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - ) + bound_logger.info("license_key.validate.invalid_conditions") raise ResourceNotFound("License key does not match required conditions") if validate.benefit_id and validate.benefit_id != license_key.benefit_id: - log.info( - "license_key.validate.invalid_benefit", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - validate_benefit_id=validate.benefit_id, - ) + bound_logger.info("license_key.validate.invalid_benefit") raise ResourceNotFound("License key does not match given benefit.") - if validate.user_id and validate.user_id != license_key.user_id: - log.warn( + if validate.customer_id and validate.customer_id != license_key.customer_id: + bound_logger.warn( "license_key.validate.invalid_owner", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - validate_user_id=validate.user_id, + validate_customer_id=validate.customer_id, ) raise ResourceNotFound("License key does not match given user.") if validate.increment_usage and license_key.limit_usage: remaining = license_key.limit_usage - license_key.usage if validate.increment_usage > remaining: - log.info( + bound_logger.info( "license_key.validate.insufficient_usage", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, usage_remaining=remaining, usage_requested=validate.increment_usage, ) @@ -287,13 +238,7 @@ async def validate( license_key.mark_validated(increment_usage=validate.increment_usage) session.add(license_key) - log.info( - "license_key.validate", - license_key_id=license_key.id, - organization_id=license_key.organization_id, - user=license_key.user_id, - benefit_id=license_key.benefit_id, - ) + bound_logger.info("license_key.validate") return (license_key, activation) async def get_activation_count( @@ -378,7 +323,7 @@ async def deactivate( ) return True - async def user_grant( + async def customer_grant( self, session: AsyncSession, *, @@ -403,18 +348,18 @@ async def user_grant( benefit_id=benefit.id, ) if license_key_id: - return await self.user_update_grant( + return await self.customer_update_grant( session, create_schema=create_schema, license_key_id=license_key_id, ) - return await self.user_create_grant( + return await self.customer_create_grant( session, create_schema=create_schema, ) - async def user_update_grant( + async def customer_update_grant( self, session: AsyncSession, *, @@ -453,7 +398,7 @@ async def user_update_grant( ) return key - async def user_create_grant( + async def customer_create_grant( self, session: AsyncSession, *, @@ -472,7 +417,7 @@ async def user_create_grant( ) return key - async def user_revoke( + async def customer_revoke( self, session: AsyncSession, customer: Customer, @@ -498,6 +443,39 @@ async def user_revoke( ) return key + async def get_customer_list( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Customer], + *, + pagination: PaginationParams, + benefit_id: UUID | None = None, + organization_ids: Sequence[UUID] | None = None, + ) -> tuple[Sequence[LicenseKey], int]: + query = self._get_select_customer_base(auth_subject).order_by( + LicenseKey.created_at.asc() + ) + + if organization_ids: + query = query.where(LicenseKey.organization_id.in_(organization_ids)) + + if benefit_id: + query = query.where(LicenseKey.benefit_id == benefit_id) + + return await paginate(session, query, pagination=pagination) + + async def get_customer_license_key( + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Customer], + license_key_id: UUID, + ) -> LicenseKey | None: + query = self._get_select_customer_base(auth_subject).where( + LicenseKey.id == license_key_id + ) + result = await session.execute(query) + return result.unique().scalar_one_or_none() + def _get_select_base(self) -> Select[tuple[LicenseKey]]: return ( select(LicenseKey) @@ -505,5 +483,15 @@ def _get_select_base(self) -> Select[tuple[LicenseKey]]: .where(LicenseKey.deleted_at.is_(None)) ) + def _get_select_customer_base( + self, auth_subject: AuthSubject[User | Customer] + ) -> Select[tuple[LicenseKey]]: + query = self._get_select_base() + if is_user(auth_subject): + raise NotImplementedError("TODO") + elif is_customer(auth_subject): + query = query.where(LicenseKey.customer_id == auth_subject.subject.id) + return query + license_key = LicenseKeyService(LicenseKey) diff --git a/server/polar/models/benefit.py b/server/polar/models/benefit.py index 0ee9934f11..6de7ca8e5c 100644 --- a/server/polar/models/benefit.py +++ b/server/polar/models/benefit.py @@ -83,7 +83,7 @@ class BenefitLicenseKeyExpirationProperties(TypedDict): class BenefitLicenseKeyActivationProperties(TypedDict): limit: int - enable_user_admin: bool + enable_customer_admin: bool class BenefitLicenseKeysProperties(BenefitProperties): diff --git a/server/polar/user/auth.py b/server/polar/user/auth.py index 062279d676..3deb66cfdf 100644 --- a/server/polar/user/auth.py +++ b/server/polar/user/auth.py @@ -3,41 +3,9 @@ from fastapi import Depends from polar.auth.dependencies import Authenticator -from polar.auth.models import Anonymous, AuthSubject, User +from polar.auth.models import AuthSubject, User from polar.auth.scope import Scope -_UserBenefitsRead = Authenticator( - required_scopes={Scope.web_default, Scope.user_benefits_read}, - allowed_subjects={User}, -) -UserBenefitsRead = Annotated[AuthSubject[User], Depends(_UserBenefitsRead)] - -_UserOrdersRead = Authenticator( - required_scopes={Scope.web_default, Scope.user_orders_read}, - allowed_subjects={User}, -) -UserOrdersRead = Annotated[AuthSubject[User], Depends(_UserOrdersRead)] - -_UserSubscriptionsRead = Authenticator( - required_scopes={Scope.web_default, Scope.user_subscriptions_read}, - allowed_subjects={User}, -) -UserSubscriptionsRead = Annotated[AuthSubject[User], Depends(_UserSubscriptionsRead)] - -_UserSubscriptionsWriteOrAnonymous = Authenticator( - required_scopes={Scope.web_default, Scope.user_subscriptions_write}, - allowed_subjects={Anonymous, User}, -) -UserSubscriptionsWriteOrAnonymous = Annotated[ - AuthSubject[Anonymous | User], Depends(_UserSubscriptionsWriteOrAnonymous) -] - -_UserSubscriptionsWrite = Authenticator( - required_scopes={Scope.web_default, Scope.user_subscriptions_write}, - allowed_subjects={User}, -) -UserSubscriptionsWrite = Annotated[AuthSubject[User], Depends(_UserSubscriptionsWrite)] - _UserAdvertisementCampaignsRead = Authenticator( required_scopes={Scope.web_default, Scope.user_advertisement_campaigns_read}, allowed_subjects={User}, @@ -53,29 +21,3 @@ UserAdvertisementCampaignsWrite = Annotated[ AuthSubject[User], Depends(_UserAdvertisementCampaignsWrite) ] - -UserDownloadablesRead = Annotated[ - AuthSubject[User], - Depends( - Authenticator( - required_scopes={ - Scope.web_default, - Scope.user_downloadables_read, - }, - allowed_subjects={User}, - ) - ), -] - -UserLicenseKeysRead = Annotated[ - AuthSubject[User], - Depends( - Authenticator( - required_scopes={ - Scope.web_default, - Scope.user_license_keys_read, - }, - allowed_subjects={User}, - ) - ), -] diff --git a/server/polar/user/endpoints/__init__.py b/server/polar/user/endpoints/__init__.py index 4db0579e57..439f993deb 100644 --- a/server/polar/user/endpoints/__init__.py +++ b/server/polar/user/endpoints/__init__.py @@ -1,19 +1,19 @@ +from polar.customer_portal.endpoints.downloadables import router as downloadables_router +from polar.customer_portal.endpoints.license_keys import router as license_keys_router +from polar.customer_portal.endpoints.order import router as order_router +from polar.customer_portal.endpoints.subscription import router as subscription_router from polar.routing import APIRouter from .advertisement import router as advertisement_router -from .benefit import router as benefit_router -from .downloadables import router as downloadables_router -from .license_keys import router as license_keys_router -from .order import router as order_router -from .subscription import router as subscription_router from .user import router as user_router router = APIRouter(prefix="/users", tags=["users"]) router.include_router(user_router) -router.include_router(benefit_router) -router.include_router(order_router) -router.include_router(subscription_router) router.include_router(advertisement_router) -router.include_router(downloadables_router) -router.include_router(license_keys_router) + +# Include customer portal endpoints for backwards compatibility +router.include_router(order_router, deprecated=True, include_in_schema=False) +router.include_router(subscription_router, deprecated=True, include_in_schema=False) +router.include_router(downloadables_router, deprecated=True, include_in_schema=False) +router.include_router(license_keys_router, deprecated=True, include_in_schema=False) diff --git a/server/polar/user/schemas/benefit.py b/server/polar/user/schemas/benefit.py deleted file mode 100644 index d446dcfef8..0000000000 --- a/server/polar/user/schemas/benefit.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Annotated - -from pydantic import TypeAdapter - -from polar.benefit.schemas import BenefitSubscriber, BenefitSubscriberAdapter -from polar.kit.schemas import ClassName, MergeJSONSchema - -UserBenefit = Annotated[ - BenefitSubscriber, - MergeJSONSchema({"title": "UserBenefit"}), - ClassName("UserBenefit"), -] -UserBenefitAdapter: TypeAdapter[UserBenefit] = BenefitSubscriberAdapter diff --git a/server/polar/user/service/benefit.py b/server/polar/user/service/benefit.py deleted file mode 100644 index 308cddd46c..0000000000 --- a/server/polar/user/service/benefit.py +++ /dev/null @@ -1,148 +0,0 @@ -import uuid -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from sqlalchemy import Select, UnaryExpression, asc, desc, select -from sqlalchemy.orm import joinedload, selectinload - -from polar.auth.models import AuthSubject -from polar.exceptions import PolarError -from polar.kit.db.postgres import AsyncSession -from polar.kit.pagination import PaginationParams, paginate -from polar.kit.services import ResourceServiceReader -from polar.kit.sorting import Sorting -from polar.models import ( - Benefit, - BenefitGrant, - Organization, - User, -) -from polar.models.benefit import BenefitType - - -class UserBenefitError(PolarError): ... - - -class UserBenefitSortProperty(StrEnum): - granted_at = "granted_at" - type = "type" - organization = "organization" - - -class UserBenefitService(ResourceServiceReader[Benefit]): - async def list( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - *, - type: Sequence[BenefitType] | None = None, - organization_id: Sequence[uuid.UUID] | None = None, - order_id: Sequence[uuid.UUID] | None = None, - subscription_id: Sequence[uuid.UUID] | None = None, - pagination: PaginationParams, - sorting: list[Sorting[UserBenefitSortProperty]] = [ - (UserBenefitSortProperty.granted_at, True) - ], - ) -> tuple[Sequence[Benefit], int]: - statement = self._get_readable_benefit_statement(auth_subject).options( - joinedload(Benefit.organization) - ) - - if type is not None: - statement = statement.where(Benefit.type.in_(type)) - - if organization_id is not None: - statement = statement.where(Benefit.organization_id.in_(organization_id)) - - if order_id is not None: - statement = statement.where( - Benefit.id.in_( - select(BenefitGrant.benefit_id).where( - BenefitGrant.order_id.in_(order_id) - ) - ) - ) - - if subscription_id is not None: - statement = statement.where( - Benefit.id.in_( - select(BenefitGrant.benefit_id).where( - BenefitGrant.subscription_id.in_(subscription_id) - ) - ) - ) - - order_by_clauses: list[UnaryExpression[Any]] = [] - for criterion, is_desc in sorting: - clause_function = desc if is_desc else asc - if criterion == UserBenefitSortProperty.granted_at: - # Join only the most recent/oldest grant - statement = statement.join( - BenefitGrant, - onclause=BenefitGrant.id - == select(BenefitGrant) - .correlate(Benefit) - .with_only_columns(BenefitGrant.id) - .where( - BenefitGrant.benefit_id == Benefit.id, - BenefitGrant.is_granted.is_(True), - ) - .order_by(clause_function(BenefitGrant.granted_at)) - .limit(1) - .scalar_subquery(), - ) - order_by_clauses.append(clause_function(BenefitGrant.granted_at)) - elif criterion == UserBenefitSortProperty.type: - order_by_clauses.append(clause_function(Benefit.type)) - elif criterion == UserBenefitSortProperty.organization: - statement = statement.join( - Organization, onclause=Benefit.organization_id == Organization.id - ) - order_by_clauses.append(clause_function(Organization.slug)) - statement = statement.order_by(*order_by_clauses) - - return await paginate(session, statement, pagination=pagination) - - async def get_by_id( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - id: uuid.UUID, - ) -> Benefit | None: - statement = ( - self._get_readable_benefit_statement(auth_subject) - .where(Benefit.id == id) - .options(joinedload(Benefit.organization)) - ) - - result = await session.execute(statement) - return result.scalar_one_or_none() - - def _get_readable_benefit_statement( - self, auth_subject: AuthSubject[User] - ) -> Select[tuple[Benefit]]: - statement = ( - select(Benefit) - .where( - Benefit.deleted_at.is_(None), - Benefit.id.in_( - select(BenefitGrant.benefit_id).where( - BenefitGrant.user_id == auth_subject.subject.id, - BenefitGrant.is_granted.is_(True), - ) - ), - ) - .options( - selectinload( - Benefit.grants.and_( - BenefitGrant.user_id == auth_subject.subject.id, - BenefitGrant.is_granted.is_(True), - ) - ) - ) - ) - return statement - - -user_benefit = UserBenefitService(Benefit) diff --git a/server/tests/customer_portal/__init__.py b/server/tests/customer_portal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/user/endpoints/test_benefits.py b/server/tests/customer_portal/endpoints/test_benefit_grant.py similarity index 68% rename from server/tests/user/endpoints/test_benefits.py rename to server/tests/customer_portal/endpoints/test_benefit_grant.py index f7fe10ba06..e941487ea3 100644 --- a/server/tests/user/endpoints/test_benefits.py +++ b/server/tests/customer_portal/endpoints/test_benefit_grant.py @@ -1,27 +1,29 @@ import pytest from httpx import AsyncClient -from polar.models import Benefit, Subscription, User +from polar.models import Benefit, Customer, Subscription +from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_benefit_grant @pytest.mark.asyncio @pytest.mark.skip_db_asserts -class TestListBenefits: - @pytest.mark.auth - async def test_user( +class TestListBenefitGrants: + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer( self, client: AsyncClient, save_fixture: SaveFixture, subscription: Subscription, benefit_organization: Benefit, benefit_organization_second: Benefit, - user: User, + customer: Customer, + customer_second: Customer, ) -> None: await create_benefit_grant( save_fixture, - user, + customer, benefit_organization, granted=True, subscription=subscription, @@ -29,13 +31,13 @@ async def test_user( await create_benefit_grant( save_fixture, - user, + customer_second, benefit_organization_second, granted=False, subscription=subscription, ) - response = await client.get("/v1/users/benefits/") + response = await client.get("/v1/customer-portal/benefit-grants/") assert response.status_code == 200 json = response.json() diff --git a/server/tests/user/endpoints/test_downloadables.py b/server/tests/customer_portal/endpoints/test_downloadables.py similarity index 82% rename from server/tests/user/endpoints/test_downloadables.py rename to server/tests/customer_portal/endpoints/test_downloadables.py index 79925866b2..58d8aa8616 100644 --- a/server/tests/user/endpoints/test_downloadables.py +++ b/server/tests/customer_portal/endpoints/test_downloadables.py @@ -8,11 +8,11 @@ from httpx import AsyncClient from polar.benefit.schemas import BenefitDownloadablesCreateProperties -from polar.file.schemas import FileRead -from polar.models import File, Organization, Product, User +from polar.customer_portal.schemas.downloadables import DownloadableRead +from polar.models import Customer, File, Organization, Product from polar.postgres import AsyncSession, sql from polar.redis import Redis -from polar.user.schemas.downloadables import DownloadableRead +from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.downloadable import TestDownloadable @@ -20,39 +20,31 @@ @pytest.mark.asyncio @pytest.mark.http_auto_expunge class TestDownloadablesEndpoints: - async def test_anonymous_list_401s( - self, - session: AsyncSession, - client: AsyncClient, - ) -> None: - response = await client.get("/v1/users/downloadables/") + async def test_anonymous_list_401s(self, client: AsyncClient) -> None: + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 401 - async def test_anonymous_download_401s( - self, - session: AsyncSession, - client: AsyncClient, - ) -> None: - response = await client.get("/v1/users/downloadables/i-am-hacker") + async def test_anonymous_download_401s(self, client: AsyncClient) -> None: + response = await client.get("/v1/customer-portal/downloadables/i-am-hacker") assert response.status_code == 401 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_revoked_404s( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: benefit, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -61,7 +53,7 @@ async def test_revoked_404s( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -72,29 +64,29 @@ async def test_revoked_404s( polar_download_url = downloadable["file"]["download"]["url"] # Revoke the benefit - await TestDownloadable.run_revoke_task(session, redis, benefit, user) + await TestDownloadable.run_revoke_task(session, redis, benefit, customer) # Polar download endpoint will now 404 response = await client.get(polar_download_url, follow_redirects=False) assert response.status_code == 404 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_wrong_token_404s( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: benefit, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -103,7 +95,7 @@ async def test_wrong_token_404s( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -113,31 +105,31 @@ async def test_wrong_token_404s( downloadable = downloadable_list[0] # Revoke the benefit - await TestDownloadable.run_revoke_task(session, redis, benefit, user) + await TestDownloadable.run_revoke_task(session, redis, benefit, customer) # Polar download endpoint will now 404 response = await client.get( - "/v1/users/downloadables/i-am-a-hacker", follow_redirects=False + "/v1/customer-portal/downloadables/i-am-a-hacker", follow_redirects=False ) assert response.status_code == 404 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_expired_token_410s( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: benefit, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -146,7 +138,7 @@ async def test_expired_token_410s( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -167,23 +159,23 @@ async def test_expired_token_410s( response = await client.get(polar_download_url, follow_redirects=False) assert response.status_code == 410 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_signatureless_url_403s( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: _, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -192,7 +184,7 @@ async def test_signatureless_url_403s( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -215,23 +207,23 @@ async def test_signatureless_url_403s( assert response.status_code == 403 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_polar_disabled_file_vanishes( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: _, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -240,7 +232,7 @@ async def test_polar_disabled_file_vanishes( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -258,7 +250,7 @@ async def test_polar_disabled_file_vanishes( ) await session.execute(statement) - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] @@ -266,23 +258,23 @@ async def test_polar_disabled_file_vanishes( assert pagination["total_count"] == 0 assert len(downloadable_list) == 0 - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_download( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, - uploaded_logo_jpg: FileRead, + uploaded_logo_jpg: File, ) -> None: _, granted = await TestDownloadable.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitDownloadablesCreateProperties( @@ -291,7 +283,7 @@ async def test_download( ) # List of downloadables - response = await client.get("/v1/users/downloadables/") + response = await client.get("/v1/customer-portal/downloadables/") assert response.status_code == 200 data = response.json() downloadable_list = data["items"] diff --git a/server/tests/user/endpoints/test_license_keys.py b/server/tests/customer_portal/endpoints/test_license_keys.py similarity index 87% rename from server/tests/user/endpoints/test_license_keys.py rename to server/tests/customer_portal/endpoints/test_license_keys.py index 72b51e753c..c4425a032c 100644 --- a/server/tests/user/endpoints/test_license_keys.py +++ b/server/tests/customer_portal/endpoints/test_license_keys.py @@ -12,7 +12,7 @@ ) from polar.kit.utils import generate_uuid, utc_now from polar.license_key.service import license_key as license_key_service -from polar.models import Organization, Product, User +from polar.models import Customer, Organization, Product from polar.postgres import AsyncSession from polar.redis import Redis from tests.fixtures.database import SaveFixture @@ -21,14 +21,14 @@ @pytest.mark.asyncio @pytest.mark.http_auto_expunge -class TestUserLicenseKeyEndpoints: +class TestCustomerLicenseKeyEndpoints: async def test_validate( self, session: AsyncSession, redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -36,7 +36,7 @@ async def test_validate( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -48,7 +48,7 @@ async def test_validate( assert lk key_only_response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -59,7 +59,7 @@ async def test_validate( assert data.get("validations") == 1 scope_benefit_404_response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -69,7 +69,7 @@ async def test_validate( assert scope_benefit_404_response.status_code == 404 scope_benefit_response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -81,12 +81,12 @@ async def test_validate( assert data.get("validations") == 2 full_response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), "benefit_id": str(lk.benefit_id), - "user_id": str(lk.user_id), + "customer_id": str(lk.customer_id), }, ) assert full_response.status_code == 200 @@ -101,7 +101,7 @@ async def test_validate_expiration( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -110,7 +110,7 @@ async def test_validate_expiration( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -122,7 +122,7 @@ async def test_validate_expiration( assert lk response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -131,7 +131,7 @@ async def test_validate_expiration( assert response.status_code == 200 with freeze_time(now + relativedelta(years=10)): response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -143,7 +143,7 @@ async def test_validate_expiration( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -156,7 +156,7 @@ async def test_validate_expiration( assert day_lk response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": day_lk.key, "organization_id": str(organization.id), @@ -165,7 +165,7 @@ async def test_validate_expiration( assert response.status_code == 200 with freeze_time(now + relativedelta(days=1, minutes=5)): response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": day_lk.key, "organization_id": str(organization.id), @@ -177,7 +177,7 @@ async def test_validate_expiration( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -190,7 +190,7 @@ async def test_validate_expiration( assert month_lk response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": month_lk.key, "organization_id": str(organization.id), @@ -200,7 +200,7 @@ async def test_validate_expiration( with freeze_time(now + relativedelta(days=28, minutes=5)): response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": month_lk.key, "organization_id": str(organization.id), @@ -210,7 +210,7 @@ async def test_validate_expiration( with freeze_time(now + relativedelta(months=1, minutes=5)): response = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": month_lk.key, "organization_id": str(organization.id), @@ -224,7 +224,7 @@ async def test_validate_usage( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -232,7 +232,7 @@ async def test_validate_usage( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -244,7 +244,7 @@ async def test_validate_usage( assert lk increment = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -257,7 +257,7 @@ async def test_validate_usage( assert data.get("usage") == 1 increment = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -270,7 +270,7 @@ async def test_validate_usage( assert data.get("usage") == 9 increment = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -289,7 +289,7 @@ async def test_validate_activation( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -297,13 +297,13 @@ async def test_validate_activation( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=1, enable_user_admin=True + limit=1, enable_customer_admin=True ), ), ) @@ -312,7 +312,7 @@ async def test_validate_activation( assert lk activate = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -325,7 +325,7 @@ async def test_validate_activation( random_id = generate_uuid() activation_404 = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -335,7 +335,7 @@ async def test_validate_activation( assert activation_404.status_code == 404 validate_activation = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -353,7 +353,7 @@ async def test_validate_conditions( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -361,13 +361,13 @@ async def test_validate_conditions( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=1, enable_user_admin=True + limit=1, enable_customer_admin=True ), ), ) @@ -378,7 +378,7 @@ async def test_validate_conditions( conditions = dict(ip="127.0.0.1", fingerprint="sdfsd23:uuojj8:sdfsdf") activate = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -391,7 +391,7 @@ async def test_validate_conditions( activation_id = data["id"] activation_404 = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -401,7 +401,7 @@ async def test_validate_conditions( assert activation_404.status_code == 404 validate_activation = await client.post( - "/v1/users/license-keys/validate", + "/v1/customer-portal/license-keys/validate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -420,7 +420,7 @@ async def test_activation( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -428,13 +428,13 @@ async def test_activation( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=1, enable_user_admin=True + limit=1, enable_customer_admin=True ), ), ) @@ -445,7 +445,7 @@ async def test_activation( label = "test" metadata = {"test": "test"} response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -465,7 +465,7 @@ async def test_unnecessary_activation( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -473,7 +473,7 @@ async def test_unnecessary_activation( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -485,7 +485,7 @@ async def test_unnecessary_activation( assert lk response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -501,7 +501,7 @@ async def test_too_many_activations( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -509,13 +509,13 @@ async def test_too_many_activations( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=1, enable_user_admin=True + limit=1, enable_customer_admin=True ), ), ) @@ -526,7 +526,7 @@ async def test_too_many_activations( label = "test" metadata = {"test": "test"} response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -538,7 +538,7 @@ async def test_too_many_activations( data = response.json() second_response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -554,7 +554,7 @@ async def test_deactivation( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, ) -> None: @@ -562,13 +562,13 @@ async def test_deactivation( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=1, enable_user_admin=True + limit=1, enable_customer_admin=True ), ), ) @@ -577,7 +577,7 @@ async def test_deactivation( assert lk response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -590,7 +590,7 @@ async def test_deactivation( activation_id = data["id"] response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -601,7 +601,7 @@ async def test_deactivation( assert response.status_code == 403 response = await client.post( - "/v1/users/license-keys/deactivate", + "/v1/customer-portal/license-keys/deactivate", json={ "key": lk.key, "organization_id": str(organization.id), @@ -611,7 +611,7 @@ async def test_deactivation( assert response.status_code == 204 response = await client.post( - "/v1/users/license-keys/activate", + "/v1/customer-portal/license-keys/activate", json={ "key": lk.key, "organization_id": str(organization.id), diff --git a/server/tests/customer_portal/service/__init__.py b/server/tests/customer_portal/service/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/user/service/test_benefit.py b/server/tests/customer_portal/service/test_benefit_grant.py similarity index 57% rename from server/tests/user/service/test_benefit.py rename to server/tests/customer_portal/service/test_benefit_grant.py index 5d952acdc2..38561e6bba 100644 --- a/server/tests/user/service/test_benefit.py +++ b/server/tests/customer_portal/service/test_benefit_grant.py @@ -1,12 +1,15 @@ import pytest from polar.auth.models import AuthSubject +from polar.customer_portal.service.benefit_grant import CustomerBenefitGrantSortProperty +from polar.customer_portal.service.benefit_grant import ( + customer_benefit_grant as customer_benefit_grant_service, +) from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams from polar.kit.sorting import Sorting -from polar.models import Benefit, Subscription, User -from polar.user.service.benefit import UserBenefitSortProperty -from polar.user.service.benefit import user_benefit as user_benefit_service +from polar.models import Benefit, Customer, Subscription +from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_benefit_grant @@ -14,45 +17,45 @@ @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestList: - @pytest.mark.auth - async def test_other_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_other_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user_second: User, + customer_second: Customer, ) -> None: await create_benefit_grant( save_fixture, - user_second, + customer_second, benefit_organization, granted=True, subscription=subscription, ) - orders, count = await user_benefit_service.list( + grants, count = await customer_benefit_grant_service.list( session, auth_subject, pagination=PaginationParams(1, 10) ) assert count == 0 - assert len(orders) == 0 + assert len(grants) == 0 - @pytest.mark.auth - async def test_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, benefit_organization_second: Benefit, - user: User, + customer: Customer, ) -> None: await create_benefit_grant( save_fixture, - user, + customer, benefit_organization, granted=True, subscription=subscription, @@ -60,18 +63,18 @@ async def test_user( await create_benefit_grant( save_fixture, - user, + customer, benefit_organization_second, granted=False, subscription=subscription, ) - orders, count = await user_benefit_service.list( + grants, count = await customer_benefit_grant_service.list( session, auth_subject, pagination=PaginationParams(1, 10) ) - assert count == 1 - assert len(orders) == 1 + assert count == 2 + assert len(grants) == 2 @pytest.mark.parametrize( "sorting", @@ -81,103 +84,104 @@ async def test_user( [("organization", False)], ], ) - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_sorting( self, - sorting: list[Sorting[UserBenefitSortProperty]], - auth_subject: AuthSubject[User], + sorting: list[Sorting[CustomerBenefitGrantSortProperty]], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user: User, + customer: Customer, ) -> None: await create_benefit_grant( save_fixture, - user, + customer, benefit_organization, granted=True, subscription=subscription, ) - orders, count = await user_benefit_service.list( + grants, count = await customer_benefit_grant_service.list( session, auth_subject, pagination=PaginationParams(1, 10), sorting=sorting ) assert count == 1 - assert len(orders) == 1 + assert len(grants) == 1 @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestGetById: - @pytest.mark.auth - async def test_other_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_other_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user_second: User, + customer_second: Customer, ) -> None: - await create_benefit_grant( + grant = await create_benefit_grant( save_fixture, - user_second, + customer_second, benefit_organization, granted=True, subscription=subscription, ) - result = await user_benefit_service.get_by_id( - session, auth_subject, benefit_organization.id + result = await customer_benefit_grant_service.get_by_id( + session, auth_subject, grant.id ) assert result is None - @pytest.mark.auth - async def test_user_revoked( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer_revoked( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user: User, + customer: Customer, ) -> None: - await create_benefit_grant( + grant = await create_benefit_grant( save_fixture, - user, + customer, benefit_organization, granted=False, subscription=subscription, ) - result = await user_benefit_service.get_by_id( - session, auth_subject, benefit_organization.id + result = await customer_benefit_grant_service.get_by_id( + session, auth_subject, grant.id ) - assert result is None + assert result is not None + assert result.is_revoked - @pytest.mark.auth - async def test_user_granted( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer_granted( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, subscription: Subscription, benefit_organization: Benefit, - user: User, - user_second: User, + customer: Customer, + customer_second: Customer, ) -> None: - user_grant = await create_benefit_grant( + customer_grant = await create_benefit_grant( save_fixture, - user, + customer, benefit_organization, granted=True, subscription=subscription, ) await create_benefit_grant( save_fixture, - user_second, + customer_second, benefit_organization, granted=True, subscription=subscription, @@ -185,12 +189,9 @@ async def test_user_granted( session.expunge_all() - result = await user_benefit_service.get_by_id( - session, auth_subject, benefit_organization.id + result = await customer_benefit_grant_service.get_by_id( + session, auth_subject, customer_grant.id ) assert result is not None - assert result.id == benefit_organization.id - - assert len(result.grants) == 1 - assert result.grants[0].id == user_grant.id + assert result.id == customer_grant.id diff --git a/server/tests/user/service/test_order.py b/server/tests/customer_portal/service/test_order.py similarity index 50% rename from server/tests/user/service/test_order.py rename to server/tests/customer_portal/service/test_order.py index f8f7a41d50..3bf967d911 100644 --- a/server/tests/user/service/test_order.py +++ b/server/tests/customer_portal/service/test_order.py @@ -1,12 +1,13 @@ import pytest from polar.auth.models import AuthSubject +from polar.customer_portal.service.order import CustomerOrderSortProperty +from polar.customer_portal.service.order import customer_order as customer_order_service from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams from polar.kit.sorting import Sorting -from polar.models import Product, User -from polar.user.service.order import UserOrderSortProperty -from polar.user.service.order import user_order as user_order_service +from polar.models import Customer, Product +from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_order @@ -14,36 +15,36 @@ @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestList: - @pytest.mark.auth - async def test_other_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_other_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, product: Product, - user_second: User, + customer_second: Customer, ) -> None: - await create_order(save_fixture, product=product, user=user_second) + await create_order(save_fixture, product=product, customer=customer_second) - orders, count = await user_order_service.list( + orders, count = await customer_order_service.list( session, auth_subject, pagination=PaginationParams(1, 10) ) assert count == 0 assert len(orders) == 0 - @pytest.mark.auth - async def test_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, ) -> None: - await create_order(save_fixture, product=product, user=user) + await create_order(save_fixture, product=product, customer=customer) - orders, count = await user_order_service.list( + orders, count = await customer_order_service.list( session, auth_subject, pagination=PaginationParams(1, 10) ) @@ -60,19 +61,19 @@ async def test_user( [("subscription", False)], ], ) - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_sorting( self, - sorting: list[Sorting[UserOrderSortProperty]], - auth_subject: AuthSubject[User], + sorting: list[Sorting[CustomerOrderSortProperty]], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, ) -> None: - await create_order(save_fixture, product=product, user=user) + await create_order(save_fixture, product=product, customer=customer) - orders, count = await user_order_service.list( + orders, count = await customer_order_service.list( session, auth_subject, pagination=PaginationParams(1, 10), sorting=sorting ) @@ -83,32 +84,34 @@ async def test_sorting( @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestGetById: - @pytest.mark.auth - async def test_other_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_other_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, product: Product, - user_second: User, + customer_second: Customer, ) -> None: - order = await create_order(save_fixture, product=product, user=user_second) + order = await create_order( + save_fixture, product=product, customer=customer_second + ) - result = await user_order_service.get_by_id(session, auth_subject, order.id) + result = await customer_order_service.get_by_id(session, auth_subject, order.id) assert result is None - @pytest.mark.auth - async def test_user( + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], save_fixture: SaveFixture, session: AsyncSession, product: Product, - user: User, + customer: Customer, ) -> None: - order = await create_order(save_fixture, product=product, user=user) + order = await create_order(save_fixture, product=product, customer=customer) - result = await user_order_service.get_by_id(session, auth_subject, order.id) + result = await customer_order_service.get_by_id(session, auth_subject, order.id) assert result is not None assert result.id == order.id diff --git a/server/tests/user/service/test_subscription.py b/server/tests/customer_portal/service/test_subscription.py similarity index 76% rename from server/tests/user/service/test_subscription.py rename to server/tests/customer_portal/service/test_subscription.py index dd5d639b87..43a2929ad2 100644 --- a/server/tests/user/service/test_subscription.py +++ b/server/tests/customer_portal/service/test_subscription.py @@ -6,23 +6,30 @@ from pytest_mock import MockerFixture from polar.auth.models import AuthSubject +from polar.customer_portal.schemas.subscription import ( + CustomerSubscriptionUpdate, +) +from polar.customer_portal.service.subscription import ( + AlreadyCanceledSubscription, + SubscriptionNotActiveOnStripe, +) +from polar.customer_portal.service.subscription import ( + customer_subscription as customer_subscription_service, +) from polar.exceptions import PolarRequestValidationError from polar.integrations.stripe.service import StripeService from polar.kit.pagination import PaginationParams -from polar.models import Organization, Product, ProductPriceFixed, Subscription, User +from polar.models import ( + Customer, + Organization, + Product, + ProductPriceFixed, + Subscription, +) from polar.models.product_price import ProductPriceType from polar.models.subscription import SubscriptionStatus from polar.postgres import AsyncSession -from polar.user.schemas.subscription import ( - UserSubscriptionUpdate, -) -from polar.user.service.subscription import ( - AlreadyCanceledSubscription, - SubscriptionNotActiveOnStripe, -) -from polar.user.service.subscription import ( - user_subscription as user_subscription_service, -) +from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import ( create_active_subscription, @@ -35,47 +42,47 @@ @pytest.fixture(autouse=True) def stripe_service_mock(mocker: MockerFixture) -> MagicMock: mock = MagicMock(spec=StripeService) - mocker.patch("polar.user.service.subscription.stripe_service", new=mock) + mocker.patch("polar.customer_portal.service.subscription.stripe_service", new=mock) return mock @pytest.mark.asyncio @pytest.mark.skip_db_asserts class TestList: - @pytest.mark.auth + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_valid( self, - auth_subject: AuthSubject[User], + auth_subject: AuthSubject[Customer], session: AsyncSession, save_fixture: SaveFixture, - user: User, - user_second: User, + customer: Customer, + customer_second: Customer, product: Product, product_second: Product, ) -> None: await create_active_subscription( save_fixture, product=product, - user=user, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) await create_active_subscription( save_fixture, product=product_second, - user=user, + customer=customer, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) await create_active_subscription( save_fixture, product=product, - user=user_second, + customer=customer_second, started_at=datetime(2023, 1, 1), ended_at=datetime(2023, 6, 15), ) - results, count = await user_subscription_service.list( + results, count = await customer_subscription_service.list( session, auth_subject, pagination=PaginationParams(1, 10), @@ -92,10 +99,10 @@ async def test_not_existing_product( self, session: AsyncSession, subscription: Subscription ) -> None: with pytest.raises(PolarRequestValidationError): - await user_subscription_service.update( + await customer_subscription_service.update( session, subscription=subscription, - subscription_update=UserSubscriptionUpdate( + subscription_update=CustomerSubscriptionUpdate( product_price_id=uuid.uuid4() ), ) @@ -112,10 +119,12 @@ async def test_not_recurring_price( save_fixture, product=product, type=ProductPriceType.one_time ) with pytest.raises(PolarRequestValidationError): - await user_subscription_service.update( + await customer_subscription_service.update( session, subscription=subscription, - subscription_update=UserSubscriptionUpdate(product_price_id=price.id), + subscription_update=CustomerSubscriptionUpdate( + product_price_id=price.id + ), ) async def test_extraneous_tier( @@ -125,10 +134,10 @@ async def test_extraneous_tier( product_organization_second: Product, ) -> None: with pytest.raises(PolarRequestValidationError): - await user_subscription_service.update( + await customer_subscription_service.update( session, subscription=subscription, - subscription_update=UserSubscriptionUpdate( + subscription_update=CustomerSubscriptionUpdate( product_price_id=product_organization_second.all_prices[0].id ), ) @@ -138,10 +147,10 @@ async def test_not_existing_stripe_subscription( ) -> None: subscription.stripe_subscription_id = None with pytest.raises(SubscriptionNotActiveOnStripe): - await user_subscription_service.update( + await customer_subscription_service.update( session, subscription=subscription, - subscription_update=UserSubscriptionUpdate( + subscription_update=CustomerSubscriptionUpdate( product_price_id=product_second.prices[0].id ), ) @@ -155,10 +164,10 @@ async def test_valid( product_second: Product, ) -> None: new_price = product_second.prices[0] - updated_subscription = await user_subscription_service.update( + updated_subscription = await customer_subscription_service.update( session, subscription=subscription, - subscription_update=UserSubscriptionUpdate( + subscription_update=CustomerSubscriptionUpdate( product_price_id=product_second.prices[0].id, ), ) @@ -187,17 +196,19 @@ async def test_already_canceled( save_fixture: SaveFixture, subscription: Subscription, product: Product, - user: User, + customer: Customer, ) -> None: subscription = await create_subscription( save_fixture, product=product, - user=user, + customer=customer, status=SubscriptionStatus.canceled, ) with pytest.raises(AlreadyCanceledSubscription): - await user_subscription_service.cancel(session, subscription=subscription) + await customer_subscription_service.cancel( + session, subscription=subscription + ) @pytest.mark.auth async def test_cancel_at_period_end( @@ -206,16 +217,18 @@ async def test_cancel_at_period_end( save_fixture: SaveFixture, subscription: Subscription, product: Product, - user: User, + customer: Customer, ) -> None: subscription = await create_active_subscription( - save_fixture, product=product, user=user + save_fixture, product=product, customer=customer ) subscription.cancel_at_period_end = True await save_fixture(subscription) with pytest.raises(AlreadyCanceledSubscription): - await user_subscription_service.cancel(session, subscription=subscription) + await customer_subscription_service.cancel( + session, subscription=subscription + ) @pytest.mark.auth async def test_free_subscription( @@ -224,16 +237,16 @@ async def test_free_subscription( save_fixture: SaveFixture, stripe_service_mock: MagicMock, product: Product, - user: User, + customer: Customer, ) -> None: subscription = await create_active_subscription( save_fixture, product=product, - user=user, + customer=customer, stripe_subscription_id=None, ) - updated_subscription = await user_subscription_service.cancel( + updated_subscription = await customer_subscription_service.cancel( session, subscription=subscription ) @@ -251,15 +264,15 @@ async def test_stripe_subscription( save_fixture: SaveFixture, stripe_service_mock: MagicMock, product: Product, - user: User, + customer: Customer, ) -> None: subscription = await create_active_subscription( save_fixture, product=product, - user=user, + customer=customer, ) - updated_subscription = await user_subscription_service.cancel( + updated_subscription = await customer_subscription_service.cancel( session, subscription=subscription ) diff --git a/server/tests/file/test_endpoints.py b/server/tests/file/test_endpoints.py index c8706b6805..0e85419360 100644 --- a/server/tests/file/test_endpoints.py +++ b/server/tests/file/test_endpoints.py @@ -29,7 +29,7 @@ async def test_create_downloadable_without_scopes( ) -> None: response = await client.post( "/v1/files/", - json=logo_png.build_create_json(organization.id), + json=logo_png.build_create(organization.id), ) assert response.status_code == 403 diff --git a/server/tests/fixtures/auth.py b/server/tests/fixtures/auth.py index e15a105e88..6ff9c6d233 100644 --- a/server/tests/fixtures/auth.py +++ b/server/tests/fixtures/auth.py @@ -4,8 +4,7 @@ from polar.auth.models import Anonymous, AuthMethod, AuthSubject, Subject from polar.auth.scope import Scope -from polar.models import User -from polar.models.organization import Organization +from polar.models import Customer, Organization, User class AuthSubjectFixture: @@ -20,6 +19,7 @@ def __init__( "organization", "organization_second", "organization_blocked", + "customer", ] = "user", scopes: set[Scope] = {Scope.web_default}, method: AuthMethod = AuthMethod.COOKIE, @@ -47,6 +47,7 @@ def auth_subject( organization: Organization, organization_second: Organization, organization_blocked: Organization, + customer: Customer, ) -> AuthSubject[Subject]: """ This fixture generates an AuthSubject instance used by the `client` fixture @@ -57,7 +58,7 @@ def auth_subject( See `pytest_generate_tests` below for more information. """ auth_subject_fixture: AuthSubjectFixture = request.param - subjects_map: dict[str, Anonymous | User | Organization] = { + subjects_map: dict[str, Anonymous | Customer | User | Organization] = { "anonymous": Anonymous(), "user": user, "user_second": user_second, @@ -65,6 +66,7 @@ def auth_subject( "organization": organization, "organization_second": organization_second, "organization_blocked": organization_blocked, + "customer": customer, } return AuthSubject( subjects_map[auth_subject_fixture.subject], diff --git a/server/tests/fixtures/file.py b/server/tests/fixtures/file.py index 2c556e9515..9665d1e6c5 100644 --- a/server/tests/fixtures/file.py +++ b/server/tests/fixtures/file.py @@ -15,16 +15,10 @@ from httpx import AsyncClient, Response from minio import Minio -from polar.auth.models import AuthSubject from polar.config import settings from polar.file.s3 import S3_SERVICES -from polar.file.schemas import ( - DownloadableFileCreate, - FileRead, - FileReadAdapter, - FileUpload, - FileUploadCompleted, -) +from polar.file.schemas import DownloadableFileCreate, FileUpload, FileUploadCompleted +from polar.file.service import file as file_service from polar.integrations.aws.s3.schemas import ( S3FileCreateMultipart, S3FileCreatePart, @@ -32,8 +26,9 @@ S3FileUploadCompletedPart, S3FileUploadPart, ) -from polar.models import Organization, User, UserOrganization +from polar.models import File, Organization from polar.models.file import FileServiceTypes +from polar.postgres import AsyncSession pwd = Path(__file__).parent.absolute() @@ -85,23 +80,23 @@ def get_chunk(self, part: S3FileUploadPart) -> bytes: ##################################################################### async def create( - self, client: AsyncClient, organization_id: UUID, parts: int = 1 + self, session: AsyncSession, organization: Organization, parts: int = 1 ) -> FileUpload: - response = await client.post( - "/v1/files/", - json=self.build_create_json(organization_id, parts=parts), + return await file_service.generate_presigned_upload( + session, + organization=organization, + create_schema=self.build_create(organization.id, parts=parts), ) - return self.validate_create_response(response, organization_id) - def build_create_json( + def build_create( self, organization_id: UUID, parts: int = 1 - ) -> dict[str, Any]: + ) -> DownloadableFileCreate: create_parts = [] for i in range(parts): part = self.build_create_part(i + 1, parts) create_parts.append(part) - create = DownloadableFileCreate( + return DownloadableFileCreate( service=FileServiceTypes.downloadable, organization_id=organization_id, name=self.name, @@ -110,8 +105,6 @@ def build_create_json( checksum_sha256_base64=self.base64, upload=S3FileCreateMultipart(parts=create_parts), ) - data = create.model_dump(mode="json") - return data def build_create_part(self, number: int, parts: int) -> S3FileCreatePart: chunk_size = self.size // parts @@ -204,24 +197,20 @@ async def put_upload( async def complete( self, - client: AsyncClient, + session: AsyncSession, created: FileUpload, uploaded: list[S3FileUploadCompletedPart], - ) -> FileRead: - payload = FileUploadCompleted( - id=created.upload.id, path=created.path, parts=uploaded - ) - payload_json = payload.model_dump(mode="json") - - response = await client.post( - f"/v1/files/{created.id}/uploaded", - json=payload_json, + ) -> File: + file = await file_service.get(session, created.id) + assert file is not None + completed = await file_service.complete_upload( + session, + file=file, + completed_schema=FileUploadCompleted( + id=created.upload.id, path=created.path, parts=uploaded + ), ) - assert response.status_code == 200 - data = response.json() - completed = FileReadAdapter.validate_python(data) - assert completed.id == created.id assert completed.is_uploaded is True s3_service = S3_SERVICES[completed.service] @@ -293,13 +282,13 @@ def empty_test_bucket(worker_id: str) -> Iterable[Any]: async def uploaded_fixture( - client: AsyncClient, - organization_id: UUID, + session: AsyncSession, + organization: Organization, file: TestFile, -) -> FileRead: - created = await file.create(client, organization_id) +) -> File: + created = await file.create(session, organization) uploaded = await file.upload(created) - completed = await file.complete(client, created, uploaded) + completed = await file.complete(session, created, uploaded) return completed @@ -314,14 +303,9 @@ def non_ascii_file_name() -> TestFile: @pytest_asyncio.fixture -async def uploaded_logo_png( - client: AsyncClient, - auth_subject: AuthSubject[User], - user_organization: UserOrganization, - organization: Organization, -) -> FileRead: +async def uploaded_logo_png(session: AsyncSession, organization: Organization) -> File: img = TestFile("logo.png") - return await uploaded_fixture(client, user_organization.organization_id, img) + return await uploaded_fixture(session, organization, img) @pytest.fixture @@ -330,14 +314,9 @@ def logo_jpg() -> TestFile: @pytest_asyncio.fixture -async def uploaded_logo_jpg( - client: AsyncClient, - auth_subject: AuthSubject[User], - user_organization: UserOrganization, - organization: Organization, -) -> FileRead: +async def uploaded_logo_jpg(session: AsyncSession, organization: Organization) -> File: img = TestFile("logo.jpg") - return await uploaded_fixture(client, user_organization.organization_id, img) + return await uploaded_fixture(session, organization, img) @pytest.fixture diff --git a/server/tests/fixtures/license_key.py b/server/tests/fixtures/license_key.py index ce97f0f2c9..cb0139296b 100644 --- a/server/tests/fixtures/license_key.py +++ b/server/tests/fixtures/license_key.py @@ -3,7 +3,7 @@ from polar.benefit.benefits.license_keys import BenefitLicenseKeysService from polar.benefit.schemas import BenefitLicenseKeysCreateProperties -from polar.models import Benefit, LicenseKey, Organization, Product, User +from polar.models import Benefit, Customer, LicenseKey, Organization, Product from polar.models.benefit import BenefitLicenseKeys, BenefitType from polar.models.benefit_grant import BenefitGrantLicenseKeysProperties from polar.models.subscription import SubscriptionStatus @@ -24,7 +24,7 @@ async def create_benefit_and_grant( session: AsyncSession, redis: Redis, save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product: Product, properties: BenefitLicenseKeysCreateProperties, @@ -40,7 +40,7 @@ async def create_benefit_and_grant( redis, save_fixture, cast(BenefitLicenseKeys, benefit), - user=user, + customer=customer, product=product, ) @@ -51,22 +51,22 @@ async def create_grant( redis: Redis, save_fixture: SaveFixture, benefit: BenefitLicenseKeys, - user: User, + customer: Customer, product: Product, ) -> tuple[BenefitLicenseKeys, BenefitGrantLicenseKeysProperties]: subscription = await create_subscription( save_fixture, product=product, - user=user, + customer=customer, status=SubscriptionStatus.active, ) await create_benefit_grant( save_fixture, - user, + customer, benefit, subscription=subscription, ) - return await cls.run_grant_task(session, redis, benefit, user) + return await cls.run_grant_task(session, redis, benefit, customer) @classmethod async def run_grant_task( @@ -74,10 +74,10 @@ async def run_grant_task( session: AsyncSession, redis: Redis, benefit: BenefitLicenseKeys, - user: User, + customer: Customer, ) -> tuple[BenefitLicenseKeys, BenefitGrantLicenseKeysProperties]: service = BenefitLicenseKeysService(session, redis) - granted = await service.grant(benefit, user, {}) + granted = await service.grant(benefit, customer, {}) return benefit, granted @classmethod @@ -86,21 +86,21 @@ async def run_revoke_task( session: AsyncSession, redis: Redis, benefit: BenefitLicenseKeys, - user: User, + customer: Customer, ) -> tuple[BenefitLicenseKeys, BenefitGrantLicenseKeysProperties]: service = BenefitLicenseKeysService(session, redis) - revoked = await service.revoke(benefit, user, {}) + revoked = await service.revoke(benefit, customer, {}) return benefit, revoked @classmethod - async def get_user_licenses( - cls, session: AsyncSession, user: User + async def get_customer_licenses( + cls, session: AsyncSession, customer: Customer ) -> Sequence[LicenseKey]: statement = ( sql.select(LicenseKey) .join(Benefit) .where( - LicenseKey.user_id == user.id, + LicenseKey.customer_id == customer.id, Benefit.deleted_at.is_(None), ) ) diff --git a/server/tests/license_key/test_endpoints.py b/server/tests/license_key/test_endpoints.py index cf64060869..745ef1e269 100644 --- a/server/tests/license_key/test_endpoints.py +++ b/server/tests/license_key/test_endpoints.py @@ -226,7 +226,7 @@ async def test_get_activation( properties=BenefitLicenseKeysCreateProperties( prefix="testing", activations=BenefitLicenseKeyActivationProperties( - limit=2, enable_user_admin=True + limit=2, enable_customer_admin=True ), ), ) From cdbf51ac51dd3f45d1746ef857400ee5fc894577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 5 Dec 2024 17:04:30 +0100 Subject: [PATCH 12/47] server: reintroduce Transaction.payment_user for backward compatibility with Pledge --- server/polar/authz/service.py | 35 --------- .../polar/benefit/benefits/downloadables.py | 4 +- server/polar/benefit/endpoints.py | 15 +--- server/polar/customer/service.py | 2 +- server/polar/models/transaction.py | 17 +++++ server/polar/transaction/endpoints.py | 6 +- server/polar/transaction/service/dispute.py | 6 +- server/polar/transaction/service/payment.py | 6 +- server/polar/transaction/service/refund.py | 3 +- .../polar/transaction/service/transaction.py | 11 ++- server/polar/user/service/advertisement.py | 7 +- .../benefit/benefits/test_downloadables.py | 2 +- server/tests/custom_field/test_service.py | 9 ++- server/tests/customer/test_service.py | 4 +- server/tests/file/test_endpoints.py | 73 ++++--------------- server/tests/fixtures/random_objects.py | 2 +- server/tests/license_key/test_endpoints.py | 24 +++--- server/tests/metrics/test_service.py | 18 +++-- server/tests/transaction/conftest.py | 8 +- .../tests/transaction/service/test_payment.py | 18 ++--- .../transaction/service/test_transaction.py | 5 +- 21 files changed, 114 insertions(+), 161 deletions(-) diff --git a/server/polar/authz/service.py b/server/polar/authz/service.py index 932059431b..33d099e5ca 100644 --- a/server/polar/authz/service.py +++ b/server/polar/authz/service.py @@ -12,7 +12,6 @@ from polar.issue.service import issue as issue_service from polar.models.account import Account from polar.models.benefit import Benefit -from polar.models.downloadable import Downloadable, DownloadableStatus from polar.models.external_organization import ExternalOrganization from polar.models.issue import Issue from polar.models.issue_reward import IssueReward @@ -21,7 +20,6 @@ from polar.models.pledge import Pledge from polar.models.product import Product from polar.models.repository import Repository -from polar.models.subscription import Subscription from polar.models.user import User from polar.models.webhook_endpoint import WebhookEndpoint from polar.postgres import AsyncSession, get_db_session @@ -47,9 +45,7 @@ class AccessType(StrEnum): | Pledge | Product | Benefit - | Subscription | WebhookEndpoint - | Downloadable | LicenseKey ) @@ -251,16 +247,6 @@ async def can( if isinstance(subject, Organization): return object.organization_id == subject.id - # - # Subscription - # - if ( - isinstance(subject, User) - and accessType == AccessType.write - and isinstance(object, Subscription) - ): - return object.user_id == subject.id - # # WebhookEndpoint # @@ -272,16 +258,6 @@ async def can( return object.organization_id == subject.id # - # Downloadable - # - - if ( - isinstance(subject, User) - and accessType == AccessType.read - and isinstance(object, Downloadable) - ): - return await self._can_user_download_file(subject, object) - # # License Key # if isinstance(object, LicenseKey): @@ -536,17 +512,6 @@ async def _can_user_write_pledge(self, subject: User, object: Pledge) -> bool: return False - # - # Downloadable - # - async def _can_user_download_file( - self, subject: User, object: Downloadable - ) -> bool: - if subject.id != object.user_id: - return False - - return object.status == DownloadableStatus.granted.value - # # WebhookEndpoint # diff --git a/server/polar/benefit/benefits/downloadables.py b/server/polar/benefit/benefits/downloadables.py index 677b3f996c..0f1d318ad1 100644 --- a/server/polar/benefit/benefits/downloadables.py +++ b/server/polar/benefit/benefits/downloadables.py @@ -7,11 +7,13 @@ from polar.auth.models import AuthSubject from polar.benefit import schemas as benefit_schemas +from polar.customer_portal.service.downloadables import ( + downloadable as downloadable_service, +) from polar.logging import Logger from polar.models import Customer, Organization, User from polar.models.benefit import BenefitDownloadables, BenefitDownloadablesProperties from polar.models.benefit_grant import BenefitGrantDownloadablesProperties -from polar.user.service.downloadables import downloadable as downloadable_service from .base import ( BenefitServiceProtocol, diff --git a/server/polar/benefit/endpoints.py b/server/polar/benefit/endpoints.py index fa48a7747c..efa184edea 100644 --- a/server/polar/benefit/endpoints.py +++ b/server/polar/benefit/endpoints.py @@ -100,16 +100,8 @@ async def grants( "If `false`, only revoked benefits will be returned. " ), ), - user_id: UUID4 | None = Query( - None, - description=("Filter by user ID."), - ), - github_user_id: int | None = Query( - None, - description=( - "Filter by GitHub user ID. " - "Only available for users who have linked their GitHub account on Polar." - ), + customer_id: MultipleQueryFilter[UUID4] | None = Query( + None, title="CustomerID Filter", description="Filter by customer." ), session: AsyncSession = Depends(get_db_session), ) -> ListResource[BenefitGrant]: @@ -127,8 +119,7 @@ async def grants( session, benefit, is_granted=is_granted, - user_id=user_id, - github_user_id=github_user_id, + customer_id=customer_id, pagination=pagination, ) diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index f667adbacb..7bfb04b9a8 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -43,7 +43,7 @@ async def list( order_by_clauses.append(clause_function(Customer.created_at)) elif criterion == CustomerSortProperty.email: order_by_clauses.append(clause_function(Customer.email)) - elif criterion == CustomerSortProperty.name: + elif criterion == CustomerSortProperty.customer_name: order_by_clauses.append(clause_function(Customer.name)) statement = statement.order_by(*order_by_clauses) diff --git a/server/polar/models/transaction.py b/server/polar/models/transaction.py index f289d2faec..ef96b885c2 100644 --- a/server/polar/models/transaction.py +++ b/server/polar/models/transaction.py @@ -15,6 +15,7 @@ Order, Organization, Pledge, + User, ) @@ -287,6 +288,22 @@ def payment_customer(cls) -> Mapped["Customer | None"]: def payment_organization(cls) -> Mapped["Organization | None"]: return relationship("Organization", lazy="raise") + payment_user_id: Mapped[UUID | None] = mapped_column( + Uuid, + ForeignKey("users.id", ondelete="set null"), + nullable=True, + index=True, + ) + """ + ID of the `User` who made the payment. + + Used for pledges. Orders and subscriptions should use `payment_customer_id`. + """ + + @declared_attr + def payment_user(cls) -> Mapped["User | None"]: + return relationship("User", lazy="raise") + pledge_id: Mapped[UUID | None] = mapped_column( Uuid, ForeignKey("pledges.id", ondelete="set null"), diff --git a/server/polar/transaction/endpoints.py b/server/polar/transaction/endpoints.py index 7e20f59002..cf4c1bd8b9 100644 --- a/server/polar/transaction/endpoints.py +++ b/server/polar/transaction/endpoints.py @@ -44,8 +44,9 @@ async def search_transactions( auth_subject: WebUser, type: TransactionType | None = Query(None), account_id: UUID4 | None = Query(None), - payment_user_id: UUID4 | None = Query(None), + payment_customer_id: UUID4 | None = Query(None), payment_organization_id: UUID4 | None = Query(None), + payment_user_id: UUID4 | None = Query(None), exclude_platform_fees: bool = Query(False), session: AsyncSession = Depends(get_db_session), ) -> ListResource[Transaction]: @@ -54,8 +55,9 @@ async def search_transactions( auth_subject.subject, type=type, account_id=account_id, - payment_user_id=payment_user_id, + payment_customer_id=payment_customer_id, payment_organization_id=payment_organization_id, + payment_user_id=payment_user_id, exclude_platform_fees=exclude_platform_fees, pagination=pagination, sorting=sorting, diff --git a/server/polar/transaction/service/dispute.py b/server/polar/transaction/service/dispute.py index 6682428767..a0b92b9e80 100644 --- a/server/polar/transaction/service/dispute.py +++ b/server/polar/transaction/service/dispute.py @@ -64,8 +64,9 @@ async def create_dispute( customer_id=payment_transaction.customer_id, charge_id=charge_id, dispute_id=dispute.id, - payment_user_id=payment_transaction.payment_user_id, + payment_customer_id=payment_transaction.payment_customer_id, payment_organization_id=payment_transaction.payment_organization_id, + payment_user_id=payment_transaction.payment_user_id, pledge_id=payment_transaction.pledge_id, issue_reward_id=payment_transaction.issue_reward_id, order_id=payment_transaction.order_id, @@ -122,8 +123,9 @@ async def create_dispute_reversal( customer_id=payment_transaction.customer_id, charge_id=charge_id, dispute_id=dispute.id, - payment_user_id=payment_transaction.payment_user_id, + payment_customer_id=payment_transaction.payment_customer_id, payment_organization_id=payment_transaction.payment_organization_id, + payment_user_id=payment_transaction.payment_user_id, pledge_id=payment_transaction.pledge_id, issue_reward_id=payment_transaction.issue_reward_id, order_id=payment_transaction.order_id, diff --git a/server/polar/transaction/service/payment.py b/server/polar/transaction/service/payment.py index 746541c414..20c670b2e6 100644 --- a/server/polar/transaction/service/payment.py +++ b/server/polar/transaction/service/payment.py @@ -7,7 +7,7 @@ from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import stripe as stripe_service from polar.integrations.stripe.utils import get_expandable_id -from polar.models import Pledge, Transaction +from polar.models import Pledge, Transaction, User from polar.models.transaction import PaymentProcessor, TransactionType from polar.organization.service import organization as organization_service from polar.pledge.service import pledge as pledge_service @@ -86,6 +86,7 @@ async def create_payment( pledge_invoice = True # Try to link with a Pledge + payment_user: User | None = None if pledge_invoice or charge.metadata.get("type") == ProductType.pledge: assert charge.payment_intent is not None payment_intent = get_expandable_id(charge.payment_intent) @@ -97,7 +98,7 @@ async def create_payment( # link from the pledge data. Happens for anonymous pledges. if payment_customer is None and payment_organization is None: await session.refresh(pledge, {"user", "by_organization"}) - payment_customer = None # TODO: Pledge customers? + payment_user = pledge.user payment_organization = pledge.by_organization risk = getattr(charge, "outcome", {}) @@ -114,6 +115,7 @@ async def create_payment( customer_id=customer_id, payment_customer=payment_customer, payment_organization=payment_organization, + payment_user=payment_user, charge_id=charge.id, pledge=pledge, risk_level=risk.get("risk_level"), diff --git a/server/polar/transaction/service/refund.py b/server/polar/transaction/service/refund.py index 279c71e68b..f3f55d547a 100644 --- a/server/polar/transaction/service/refund.py +++ b/server/polar/transaction/service/refund.py @@ -74,8 +74,9 @@ async def create_refunds( customer_id=payment_transaction.customer_id, charge_id=charge.id, refund_id=refund.id, - payment_user_id=payment_transaction.payment_user_id, + payment_customer_id=payment_transaction.payment_customer_id, payment_organization_id=payment_transaction.payment_organization_id, + payment_user_id=payment_transaction.payment_user_id, pledge_id=payment_transaction.pledge_id, issue_reward_id=payment_transaction.issue_reward_id, order_id=payment_transaction.order_id, diff --git a/server/polar/transaction/service/transaction.py b/server/polar/transaction/service/transaction.py index ffffc4e05b..1d684ca8ea 100644 --- a/server/polar/transaction/service/transaction.py +++ b/server/polar/transaction/service/transaction.py @@ -45,8 +45,9 @@ async def search( *, type: TransactionType | None = None, account_id: uuid.UUID | None = None, - payment_user_id: uuid.UUID | None = None, + payment_customer_id: uuid.UUID | None = None, payment_organization_id: uuid.UUID | None = None, + payment_user_id: uuid.UUID | None = None, exclude_platform_fees: bool = False, pagination: PaginationParams, sorting: list[Sorting[TransactionSortProperty]] = [ @@ -79,12 +80,16 @@ async def search( statement = statement.where(Transaction.type == type) if account_id is not None: statement = statement.where(Transaction.account_id == account_id) - if payment_user_id is not None: - statement = statement.where(Transaction.payment_user_id == payment_user_id) + if payment_customer_id is not None: + statement = statement.where( + Transaction.payment_customer_id == payment_customer_id + ) if payment_organization_id is not None: statement = statement.where( Transaction.payment_organization_id == payment_organization_id ) + if payment_user_id is not None: + statement = statement.where(Transaction.payment_user_id == payment_user_id) if exclude_platform_fees: statement = statement.where(Transaction.platform_fee_type.is_(None)) diff --git a/server/polar/user/service/advertisement.py b/server/polar/user/service/advertisement.py index 6b59fe3a0c..409376ec9b 100644 --- a/server/polar/user/service/advertisement.py +++ b/server/polar/user/service/advertisement.py @@ -6,6 +6,9 @@ from sqlalchemy import Select, UnaryExpression, asc, desc, select, update from polar.auth.models import AuthSubject +from polar.customer_portal.service.benefit_grant import ( + customer_benefit_grant as customer_benefit_grant_service, +) from polar.exceptions import PolarError, PolarRequestValidationError from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams, paginate @@ -19,7 +22,6 @@ UserAdvertisementCampaignEnable, UserAdvertisementCampaignUpdate, ) -from .benefit import user_benefit as user_benefit_service class UserAdvertisementError(PolarError): ... @@ -111,9 +113,10 @@ async def enable( advertisement_campaign: AdvertisementCampaign, advertisement_campaign_enable: UserAdvertisementCampaignEnable, ) -> Sequence[BenefitGrant]: - benefit = await user_benefit_service.get_by_id( + grant = await customer_benefit_grant_service.get_by_id( session, auth_subject, advertisement_campaign_enable.benefit_id ) + benefit = grant.benefit if benefit is None: raise PolarRequestValidationError( diff --git a/server/tests/benefit/benefits/test_downloadables.py b/server/tests/benefit/benefits/test_downloadables.py index fbf9214ea5..0dcd2592ef 100644 --- a/server/tests/benefit/benefits/test_downloadables.py +++ b/server/tests/benefit/benefits/test_downloadables.py @@ -13,7 +13,7 @@ @pytest.mark.asyncio -@pytest.mark.http_auto_expunge +@pytest.mark.skip_db_asserts class TestDownloadblesBenefit: @pytest.mark.auth async def test_grant_one( diff --git a/server/tests/custom_field/test_service.py b/server/tests/custom_field/test_service.py index 6f04db88b5..0e25a6e778 100644 --- a/server/tests/custom_field/test_service.py +++ b/server/tests/custom_field/test_service.py @@ -3,7 +3,7 @@ from polar.custom_field.schemas import CustomFieldUpdateText from polar.custom_field.service import custom_field as custom_field_service -from polar.models import Order, Organization, Product, User +from polar.models import Customer, Order, Organization, Product from polar.models.custom_field import CustomFieldText, CustomFieldType from polar.postgres import AsyncSession from tests.fixtures.database import SaveFixture @@ -24,14 +24,17 @@ async def text_field( @pytest_asyncio.fixture async def order_text_field_data( - save_fixture: SaveFixture, product: Product, user: User, text_field: CustomFieldText + save_fixture: SaveFixture, + product: Product, + customer: Customer, + text_field: CustomFieldText, ) -> Order: custom_field_data = {"foo": "bar"} custom_field_data[text_field.slug] = "text1" return await create_order( save_fixture, product=product, - user=user, + customer=customer, custom_field_data=custom_field_data, ) diff --git a/server/tests/customer/test_service.py b/server/tests/customer/test_service.py index 4fe0eb892b..efbd2fd7ad 100644 --- a/server/tests/customer/test_service.py +++ b/server/tests/customer/test_service.py @@ -76,7 +76,7 @@ async def test_valid( organization: Organization, user_organization: UserOrganization, ) -> None: - payload: dict[str, Any] = {"email": "customer@example.com"} + payload: dict[str, Any] = {"email": "customer.new@example.com"} if is_user(auth_subject): payload["organization_id"] = str(organization.id) @@ -88,7 +88,7 @@ async def test_valid( ) await session.flush() - assert customer.email == "customer@example.com" + assert customer.email == "customer.new@example.com" @pytest.mark.asyncio diff --git a/server/tests/file/test_endpoints.py b/server/tests/file/test_endpoints.py index 0e85419360..116548d9fd 100644 --- a/server/tests/file/test_endpoints.py +++ b/server/tests/file/test_endpoints.py @@ -3,18 +3,16 @@ import pytest from httpx import AsyncClient, ReadError -from polar.auth.models import AuthSubject from polar.file.s3 import S3_SERVICES from polar.file.service import file as file_service from polar.integrations.aws.s3.exceptions import S3FileError -from polar.models import Organization, User, UserOrganization +from polar.models import Organization from polar.postgres import AsyncSession -from tests.fixtures.auth import AuthSubjectFixture from tests.fixtures.file import TestFile @pytest.mark.asyncio -@pytest.mark.http_auto_expunge +@pytest.mark.skip_db_asserts class TestEndpoints: async def test_anonymous_create_401( self, client: AsyncClient, organization: Organization @@ -29,51 +27,31 @@ async def test_create_downloadable_without_scopes( ) -> None: response = await client.post( "/v1/files/", - json=logo_png.build_create(organization.id), + json=logo_png.build_create(organization.id).model_dump(mode="json"), ) assert response.status_code == 403 - @pytest.mark.http_auto_expunge - @pytest.mark.auth( - AuthSubjectFixture(subject="user"), - ) async def test_create_downloadable_with_web_scope( - self, - client: AsyncClient, - auth_subject: AuthSubject[User], - session: AsyncSession, - user_organization: UserOrganization, - logo_png: TestFile, + self, session: AsyncSession, organization: Organization, logo_png: TestFile ) -> None: - organization_id = user_organization.organization_id - await logo_png.create(client, organization_id) + await logo_png.create(session, organization) - @pytest.mark.http_auto_expunge - @pytest.mark.auth async def test_create_downloadable_with_non_ascii_name( self, - client: AsyncClient, - user_organization: UserOrganization, + session: AsyncSession, + organization: Organization, non_ascii_file_name: TestFile, ) -> None: - organization_id = user_organization.organization_id - await non_ascii_file_name.create(client, organization_id) + await non_ascii_file_name.create(session, organization) - @pytest.mark.http_auto_expunge - @pytest.mark.auth( - AuthSubjectFixture(subject="user"), - ) async def test_incomplete_upload_with_web_scope( self, - client: AsyncClient, - auth_subject: AuthSubject[User], session: AsyncSession, - user_organization: UserOrganization, + organization: Organization, logo_jpg: TestFile, ) -> None: - organization_id = user_organization.organization_id - created = await logo_jpg.create(client, organization_id) + created = await logo_jpg.create(session, organization) await logo_jpg.upload(created) @@ -86,21 +64,10 @@ async def test_incomplete_upload_with_web_scope( assert record assert record.is_uploaded is False - @pytest.mark.http_auto_expunge - @pytest.mark.auth( - AuthSubjectFixture(subject="user"), - ) async def test_upload_without_signature( - self, - client: AsyncClient, - auth_subject: AuthSubject[User], - session: AsyncSession, - user_organization: UserOrganization, - logo_jpg: TestFile, + self, session: AsyncSession, organization: Organization, logo_jpg: TestFile ) -> None: - organization_id = user_organization.organization_id - - created = await logo_jpg.create(client, organization_id) + created = await logo_jpg.create(session, organization) part = created.upload.parts[0] @@ -145,22 +112,12 @@ async def test_upload_without_signature( assert record assert record.is_uploaded is False - @pytest.mark.http_auto_expunge - @pytest.mark.auth( - AuthSubjectFixture(subject="user"), - ) async def test_upload_with_web_scope( - self, - client: AsyncClient, - auth_subject: AuthSubject[User], - session: AsyncSession, - user_organization: UserOrganization, - logo_jpg: TestFile, + self, session: AsyncSession, organization: Organization, logo_jpg: TestFile ) -> None: - organization_id = user_organization.organization_id - created = await logo_jpg.create(client, organization_id) + created = await logo_jpg.create(session, organization) uploaded = await logo_jpg.upload(created) - await logo_jpg.complete(client, created, uploaded) + await logo_jpg.complete(session, created, uploaded) record = await file_service.get(session, created.id, allow_deleted=True) assert record diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 9b0d6ae5f1..135c123116 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -839,7 +839,7 @@ async def create_customer( organization: Organization, email: str = "customer@example.com", email_verified: bool = False, - name="Customer", + name: str = "Customer", stripe_customer_id: str = "STRIPE_CUSTOMER_ID", ) -> Customer: customer = Customer( diff --git a/server/tests/license_key/test_endpoints.py b/server/tests/license_key/test_endpoints.py index 745ef1e269..3c6b06a6e7 100644 --- a/server/tests/license_key/test_endpoints.py +++ b/server/tests/license_key/test_endpoints.py @@ -12,7 +12,7 @@ from polar.kit.pagination import PaginationParams from polar.kit.utils import generate_uuid, utc_now from polar.license_key.service import license_key as license_key_service -from polar.models import Organization, Product, User, UserOrganization +from polar.models import Customer, Organization, Product, User, UserOrganization from polar.postgres import AsyncSession from polar.redis import Redis from tests.fixtures.auth import AuthSubjectFixture @@ -42,15 +42,15 @@ async def test_get_unauthorized_401( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, organization: Organization, product: Product, + customer: Customer, ) -> None: benefit, granted = await TestLicenseKey.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -74,16 +74,16 @@ async def test_get_authorized( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, user_organization: UserOrganization, organization: Organization, product: Product, + customer: Customer, ) -> None: benefit, granted = await TestLicenseKey.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -110,16 +110,16 @@ async def test_update( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, user_organization: UserOrganization, organization: Organization, product: Product, + customer: Customer, ) -> None: benefit, granted = await TestLicenseKey.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -159,16 +159,16 @@ async def test_list( client: AsyncClient, save_fixture: SaveFixture, auth_subject: AuthSubject[User | Organization], - user: User, user_organization: UserOrganization, organization: Organization, product: Product, + customer: Customer, ) -> None: benefit, granted = await TestLicenseKey.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -179,7 +179,7 @@ async def test_list( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( @@ -211,16 +211,16 @@ async def test_get_activation( redis: Redis, client: AsyncClient, save_fixture: SaveFixture, - user: User, user_organization: UserOrganization, organization: Organization, product: Product, + customer: Customer, ) -> None: benefit, granted = await TestLicenseKey.create_benefit_and_grant( session, redis, save_fixture, - user=user, + customer=customer, organization=organization, product=product, properties=BenefitLicenseKeysCreateProperties( diff --git a/server/tests/metrics/test_service.py b/server/tests/metrics/test_service.py index 1554bb7e43..e9bb6d2144 100644 --- a/server/tests/metrics/test_service.py +++ b/server/tests/metrics/test_service.py @@ -9,6 +9,7 @@ from polar.metrics.queries import Interval from polar.metrics.service import metrics as metrics_service from polar.models import ( + Customer, Order, Organization, Product, @@ -121,7 +122,7 @@ def _date_to_datetime(date: date) -> datetime: async def _create_fixtures( save_fixture: SaveFixture, - user: User, + customer: Customer, organization: Organization, product_fixtures: dict[str, ProductFixture], subscription_fixtures: dict[str, SubscriptionFixture], @@ -139,7 +140,7 @@ async def _create_fixtures( subscription = await create_subscription( save_fixture, product=products[subscription_fixture["product"]], - user=user, + customer=customer, status=SubscriptionStatus.active, started_at=_date_to_datetime(subscription_fixture["started_at"]), ended_at=( @@ -158,7 +159,7 @@ async def _create_fixtures( order = await create_order( save_fixture, product=products[order_fixture["product"]], - user=user, + customer=customer, amount=order_fixture["amount"], created_at=_date_to_datetime(order_fixture["created_at"]), subscription=order_subscription, @@ -171,10 +172,10 @@ async def _create_fixtures( @pytest_asyncio.fixture async def fixtures( - save_fixture: SaveFixture, user: User, organization: Organization + save_fixture: SaveFixture, customer: Customer, organization: Organization ) -> tuple[dict[str, Product], dict[str, Subscription], dict[str, Order]]: return await _create_fixtures( - save_fixture, user, organization, PRODUCTS, SUBSCRIPTIONS, ORDERS + save_fixture, customer, organization, PRODUCTS, SUBSCRIPTIONS, ORDERS ) @@ -489,7 +490,7 @@ async def test_values_free_subscription( session: AsyncSession, auth_subject: AuthSubject[User], user_organization: UserOrganization, - user: User, + customer: Customer, organization: Organization, ) -> None: subscriptions: dict[str, SubscriptionFixture] = { @@ -504,7 +505,7 @@ async def test_values_free_subscription( }, } await _create_fixtures( - save_fixture, user, organization, PRODUCTS, subscriptions, {} + save_fixture, customer, organization, PRODUCTS, subscriptions, {} ) metrics = await metrics_service.get_metrics( @@ -577,6 +578,7 @@ async def test_values_subscription_canceled_during_interval( auth_subject: AuthSubject[User], user_organization: UserOrganization, user: User, + customer: Customer, organization: Organization, ) -> None: """ @@ -597,7 +599,7 @@ async def test_values_subscription_canceled_during_interval( } } await _create_fixtures( - save_fixture, user, organization, PRODUCTS, subscriptions, {} + save_fixture, customer, organization, PRODUCTS, subscriptions, {} ) metrics = await metrics_service.get_metrics( diff --git a/server/tests/transaction/conftest.py b/server/tests/transaction/conftest.py index f71873b04f..63d1e8809b 100644 --- a/server/tests/transaction/conftest.py +++ b/server/tests/transaction/conftest.py @@ -35,6 +35,7 @@ async def create_transaction( account: Account | None = None, payment_customer: Customer | None = None, payment_organization: Organization | None = None, + payment_user: User | None = None, type: TransactionType = TransactionType.balance, amount: int = 1000, account_currency: str = "eur", @@ -59,6 +60,7 @@ async def create_transaction( account=account, payment_customer=payment_customer, payment_organization=payment_organization, + payment_user=payment_user, pledge=pledge, issue_reward=issue_reward, order=order, @@ -193,12 +195,10 @@ async def account_transactions( @pytest_asyncio.fixture -async def user_transactions( - save_fixture: SaveFixture, customer: Customer -) -> list[Transaction]: +async def user_transactions(save_fixture: SaveFixture, user: User) -> list[Transaction]: return [ await create_transaction( - save_fixture, type=TransactionType.payment, customer=customer + save_fixture, type=TransactionType.payment, payment_user=user ), ] diff --git a/server/tests/transaction/service/test_payment.py b/server/tests/transaction/service/test_payment.py index 7d3cb43a88..cbe31100a4 100644 --- a/server/tests/transaction/service/test_payment.py +++ b/server/tests/transaction/service/test_payment.py @@ -8,7 +8,7 @@ from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService -from polar.models import Organization, Pledge, Transaction, User +from polar.models import Customer, Organization, Pledge, Transaction, User from polar.models.transaction import PaymentProcessor, TransactionType from polar.postgres import AsyncSession from polar.transaction.service.payment import ( # type: ignore[attr-defined] @@ -140,24 +140,22 @@ async def test_existing_transaction( pytest.param("normal", 4), ], ) - async def test_customer_user( + async def test_customer( self, session: AsyncSession, save_fixture: SaveFixture, pledge: Pledge, - user: User, + customer: Customer, stripe_service_mock: MagicMock, risk_level: str | None, risk_score: int | None, ) -> None: - user.stripe_customer_id = "STRIPE_CUSTOMER_ID" - await save_fixture(user) pledge.payment_id = "STRIPE_PAYMENT_ID" await save_fixture(pledge) stripe_balance_transaction = build_stripe_balance_transaction() stripe_charge = build_stripe_charge( - customer=user.stripe_customer_id, + customer=customer.stripe_customer_id, payment_intent=pledge.payment_id, balance_transaction=stripe_balance_transaction.id, risk_level=risk_level, @@ -176,8 +174,8 @@ async def test_customer_user( ) assert transaction.type == TransactionType.payment - assert transaction.customer_id == user.stripe_customer_id - assert transaction.payment_user == user + assert transaction.customer_id == customer.stripe_customer_id + assert transaction.payment_customer == customer assert transaction.payment_organization is None assert transaction.risk_level == risk_level assert transaction.risk_score == risk_score @@ -216,7 +214,7 @@ async def test_customer_organization( assert transaction.type == TransactionType.payment assert transaction.customer_id == organization.stripe_customer_id - assert transaction.payment_user is None + assert transaction.payment_customer is None assert transaction.payment_organization == organization async def test_not_existing_pledge( @@ -320,7 +318,7 @@ async def test_anonymous_pledge( assert transaction.type == TransactionType.payment assert transaction.pledge == pledge - assert transaction.payment_user == pledge.user + assert transaction.payment_customer == pledge.user assert transaction.payment_organization == pledge.by_organization async def test_tax_metadata( diff --git a/server/tests/transaction/service/test_transaction.py b/server/tests/transaction/service/test_transaction.py index 8e90eaacc9..20ababc1ab 100644 --- a/server/tests/transaction/service/test_transaction.py +++ b/server/tests/transaction/service/test_transaction.py @@ -130,7 +130,10 @@ async def test_filter_payment_user( session.expunge_all() results, count = await transaction_service.search( - session, user, payment_user_id=user.id, pagination=PaginationParams(1, 10) + session, + user, + payment_user_id=user.id, + pagination=PaginationParams(1, 10), ) assert count == len(user_transactions) From 71e8beda73903c3e95b00c280f1bcff84bc9f0a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 5 Dec 2024 17:27:04 +0100 Subject: [PATCH 13/47] server: add UserCustomer association table and use it in customer portal queries --- .../customer_portal/service/benefit_grant.py | 9 +++++- .../customer_portal/service/downloadables.py | 10 +++++-- server/polar/customer_portal/service/order.py | 18 ++++++++++-- .../customer_portal/service/subscription.py | 9 +++++- server/polar/license_key/service.py | 17 ++++++++--- server/polar/models/__init__.py | 2 ++ server/polar/models/user.py | 15 +++++++++- server/polar/models/user_customer.py | 29 +++++++++++++++++++ 8 files changed, 98 insertions(+), 11 deletions(-) create mode 100644 server/polar/models/user_customer.py diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py index f40afec1a4..9edd3c59bb 100644 --- a/server/polar/customer_portal/service/benefit_grant.py +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -17,6 +17,7 @@ Customer, Organization, User, + UserCustomer, ) from polar.models.benefit import BenefitType @@ -106,7 +107,13 @@ def _get_readable_benefit_grant_statement( ) if is_user(auth_subject): - raise NotImplementedError("TODO") + statement = statement.where( + BenefitGrant.customer_id.in_( + select(UserCustomer.customer_id).where( + UserCustomer.user_id == auth_subject.subject.id + ) + ) + ) elif is_customer(auth_subject): statement = statement.where( BenefitGrant.customer_id == auth_subject.subject.id diff --git a/server/polar/customer_portal/service/downloadables.py b/server/polar/customer_portal/service/downloadables.py index ba26d4bad3..ca212b5aad 100644 --- a/server/polar/customer_portal/service/downloadables.py +++ b/server/polar/customer_portal/service/downloadables.py @@ -18,7 +18,7 @@ from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceService from polar.kit.utils import utc_now -from polar.models import Benefit, Customer, User +from polar.models import Benefit, Customer, User, UserCustomer from polar.models.downloadable import Downloadable, DownloadableStatus from polar.models.file import File from polar.postgres import AsyncSession, sql @@ -244,7 +244,13 @@ def _get_base_query( ) if is_user(auth_subject): - raise NotImplementedError("TODO") + statement = statement.where( + Downloadable.customer_id.in_( + sql.select(UserCustomer.customer_id).where( + UserCustomer.user_id == auth_subject.subject.id + ) + ) + ) elif is_customer(auth_subject): statement = statement.where( Downloadable.customer_id == auth_subject.subject.id diff --git a/server/polar/customer_portal/service/order.py b/server/polar/customer_portal/service/order.py index c390ae87ab..688cd78194 100644 --- a/server/polar/customer_portal/service/order.py +++ b/server/polar/customer_portal/service/order.py @@ -13,7 +13,15 @@ from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.kit.sorting import Sorting -from polar.models import Customer, Order, Organization, Product, ProductPrice, User +from polar.models import ( + Customer, + Order, + Organization, + Product, + ProductPrice, + User, + UserCustomer, +) from polar.models.product_price import ProductPriceType @@ -149,7 +157,13 @@ def _get_readable_order_statement( ) if is_user(auth_subject): - raise NotImplementedError("TODO") + statement = statement.where( + Order.customer_id.in_( + select(UserCustomer.customer_id).where( + UserCustomer.user_id == auth_subject.subject.id + ) + ) + ) elif is_customer(auth_subject): customer = auth_subject.subject statement = statement.where(Order.customer_id == customer.id) diff --git a/server/polar/customer_portal/service/subscription.py b/server/polar/customer_portal/service/subscription.py index 43aca92c73..531c4373e2 100644 --- a/server/polar/customer_portal/service/subscription.py +++ b/server/polar/customer_portal/service/subscription.py @@ -24,6 +24,7 @@ ProductPriceFree, Subscription, User, + UserCustomer, ) from polar.models.product_price import ProductPriceType from polar.models.subscription import SubscriptionStatus @@ -305,7 +306,13 @@ def _get_readable_subscription_statement( statement = select(Subscription).where(Subscription.deleted_at.is_(None)) if is_user(auth_subject): - raise NotImplementedError("TODO") + statement = statement.where( + Subscription.customer_id.in_( + select(UserCustomer.customer_id).where( + UserCustomer.user_id == auth_subject.subject.id + ) + ) + ) elif is_customer(auth_subject): statement = statement.where( Subscription.customer_id == auth_subject.subject.id diff --git a/server/polar/license_key/service.py b/server/polar/license_key/service.py index 39eea3e81f..4c81578c3c 100644 --- a/server/polar/license_key/service.py +++ b/server/polar/license_key/service.py @@ -17,6 +17,7 @@ LicenseKeyActivation, Organization, User, + UserCustomer, UserOrganization, ) from polar.models.benefit import BenefitLicenseKeys @@ -486,12 +487,20 @@ def _get_select_base(self) -> Select[tuple[LicenseKey]]: def _get_select_customer_base( self, auth_subject: AuthSubject[User | Customer] ) -> Select[tuple[LicenseKey]]: - query = self._get_select_base() + statement = self._get_select_base() if is_user(auth_subject): - raise NotImplementedError("TODO") + statement = statement.where( + LicenseKey.customer_id.in_( + select(UserCustomer.customer_id).where( + UserCustomer.user_id == auth_subject.subject.id + ) + ) + ) elif is_customer(auth_subject): - query = query.where(LicenseKey.customer_id == auth_subject.subject.id) - return query + statement = statement.where( + LicenseKey.customer_id == auth_subject.subject.id + ) + return statement license_key = LicenseKeyService(LicenseKey) diff --git a/server/polar/models/__init__.py b/server/polar/models/__init__.py index d42f0cbac9..f0bac160b8 100644 --- a/server/polar/models/__init__.py +++ b/server/polar/models/__init__.py @@ -44,6 +44,7 @@ from .subscription import Subscription from .transaction import Transaction from .user import OAuthAccount, User +from .user_customer import UserCustomer from .user_notification import UserNotification from .user_organization import UserOrganization from .user_session import UserSession @@ -97,6 +98,7 @@ "Subscription", "Transaction", "User", + "UserCustomer", "UserNotification", "UserOrganization", "UserSession", diff --git a/server/polar/models/user.py b/server/polar/models/user.py index d7f1860d97..8441439a5f 100644 --- a/server/polar/models/user.py +++ b/server/polar/models/user.py @@ -1,7 +1,7 @@ import time from datetime import datetime from enum import StrEnum -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import UUID from sqlalchemy import ( @@ -16,6 +16,7 @@ func, ) from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from sqlalchemy.schema import Index, UniqueConstraint @@ -24,6 +25,10 @@ from .account import Account +if TYPE_CHECKING: + from .customer import Customer + from .user_customer import UserCustomer + class OAuthPlatform(StrEnum): # maximum allowed length is 32 chars @@ -110,6 +115,14 @@ def account(cls) -> Mapped[Account | None]: def oauth_accounts(cls) -> Mapped[list[OAuthAccount]]: return relationship(OAuthAccount, lazy="joined", back_populates="user") + @declared_attr + def user_customers(cls) -> Mapped[list["UserCustomer"]]: + return relationship("UserCustomer", lazy="raise", back_populates="user") + + customers: AssociationProxy[list["Customer"]] = association_proxy( + "user_customers", "customer" + ) + accepted_terms_of_service: Mapped[bool] = mapped_column( Boolean, nullable=False, diff --git a/server/polar/models/user_customer.py b/server/polar/models/user_customer.py new file mode 100644 index 0000000000..d6b96d171a --- /dev/null +++ b/server/polar/models/user_customer.py @@ -0,0 +1,29 @@ +from uuid import UUID + +from sqlalchemy import ForeignKey, Uuid +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship + +from polar.kit.db.models.base import RecordModel + +from .customer import Customer +from .user import User + + +class UserCustomer(RecordModel): + __tablename__ = "user_customers" + + user_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("users.id", ondelete="cascade"), nullable=False + ) + customer_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False + ) + + @declared_attr + def user(cls) -> Mapped[User]: + return relationship(User, lazy="raise", back_populates="user_customers") + + @declared_attr + def customer(cls) -> Mapped[Customer]: + # This is an association table, so eager loading makes sense + return relationship(Customer, lazy="joined") From 89e710828770581c02c763470882b791ec315945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 6 Dec 2024 14:04:55 +0100 Subject: [PATCH 14/47] server/customer: implement oauth account connection --- server/polar/benefit/benefits/discord.py | 50 ++++++- .../benefit/benefits/github_repository.py | 6 +- server/polar/customer_portal/auth.py | 8 ++ .../customer_portal/endpoints/__init__.py | 6 +- .../endpoints/oauth_accounts.py | 131 ++++++++++++++++++ server/polar/integrations/discord/service.py | 55 -------- server/polar/kit/jwt.py | 1 + server/polar/models/customer.py | 39 +++++- server/pyproject.toml | 2 +- server/uv.lock | 8 +- 10 files changed, 230 insertions(+), 76 deletions(-) create mode 100644 server/polar/customer_portal/endpoints/oauth_accounts.py diff --git a/server/polar/benefit/benefits/discord.py b/server/polar/benefit/benefits/discord.py index 7211b22da8..1356072690 100644 --- a/server/polar/benefit/benefits/discord.py +++ b/server/polar/benefit/benefits/discord.py @@ -2,16 +2,16 @@ import httpx import structlog +from httpx_oauth.clients.discord import DiscordOAuth2 from polar.auth.models import AuthSubject +from polar.config import settings from polar.integrations.discord.service import discord_bot as discord_bot_service -from polar.integrations.discord.service import ( - discord_customer as discord_customer_service, -) from polar.logging import Logger from polar.models import Customer, Organization, User from polar.models.benefit import BenefitDiscord, BenefitDiscordProperties from polar.models.benefit_grant import BenefitGrantDiscordProperties +from polar.models.customer import CustomerOAuthAccount, CustomerOAuthPlatform from .base import ( BenefitActionRequiredError, @@ -62,9 +62,7 @@ async def grant( "The customer needs to connect their Discord account" ) - oauth_account = await discord_customer_service.get_oauth_account( - self.session, customer, account_id - ) + oauth_account = await self._get_customer_oauth_account(customer, account_id) try: await discord_bot_service.add_member( @@ -81,7 +79,7 @@ async def grant( bound_logger.debug("Benefit granted") - # Store guild, and role an as it may change if the benefit is updated + # Store guild, and role as it may change if the benefit is updated return { **grant_properties, "guild_id": guild_id, @@ -168,3 +166,41 @@ async def validate_properties( ) return cast(BenefitDiscordProperties, properties) + + async def _get_customer_oauth_account( + self, customer: Customer, account_id: str + ) -> CustomerOAuthAccount: + oauth_account = customer.get_oauth_account( + account_id, CustomerOAuthPlatform.discord + ) + if oauth_account is None: + raise BenefitActionRequiredError( + "The customer needs to connect their Discord account" + ) + + if oauth_account.is_expired(): + if oauth_account.refresh_token is None: + raise BenefitActionRequiredError( + "The customer needs to reconnect their Discord account" + ) + + log.debug( + "Refresh Discord access token", + oauth_account_id=oauth_account.account_id, + customer_id=str(customer.id), + ) + client = DiscordOAuth2( + settings.DISCORD_CLIENT_ID, + settings.DISCORD_CLIENT_SECRET, + scopes=["identify", "email", "guilds.join"], + ) + refreshed_token_data = await client.refresh_token( + oauth_account.refresh_token + ) + oauth_account.access_token = refreshed_token_data["access_token"] + oauth_account.expires_at = refreshed_token_data["expires_at"] + oauth_account.refresh_token = refreshed_token_data["refresh_token"] + customer.set_oauth_account(oauth_account, CustomerOAuthPlatform.discord) + self.session.add(customer) + + return oauth_account diff --git a/server/polar/benefit/benefits/github_repository.py b/server/polar/benefit/benefits/github_repository.py index 74736e162a..fd3fad07e5 100644 --- a/server/polar/benefit/benefits/github_repository.py +++ b/server/polar/benefit/benefits/github_repository.py @@ -6,7 +6,6 @@ from polar.auth.models import AuthSubject, is_organization, is_user from polar.authz.service import AccessType, Authz -from polar.config import settings from polar.integrations.github import client as github from polar.integrations.github import types from polar.integrations.github_repository_benefit.service import ( @@ -19,6 +18,7 @@ BenefitGitHubRepositoryProperties, ) from polar.models.benefit_grant import BenefitGrantGitHubRepositoryProperties +from polar.models.customer import CustomerOAuthPlatform from polar.posthog import posthog from polar.repository.service import repository as repository_service @@ -66,7 +66,7 @@ async def grant( ) oauth_account = customer.get_oauth_account( - f"github:{settings.GITHUB_CLIENT_ID}:{account_id}" + account_id, CustomerOAuthPlatform.github ) if oauth_account is None or oauth_account.account_username is None: @@ -145,7 +145,7 @@ async def revoke( ) oauth_account = customer.get_oauth_account( - f"github:{settings.GITHUB_CLIENT_ID}:{account_id}" + account_id, CustomerOAuthPlatform.github ) if oauth_account is None or oauth_account.account_username is None: diff --git a/server/polar/customer_portal/auth.py b/server/polar/customer_portal/auth.py index 2fb276fa52..ab699a84a8 100644 --- a/server/polar/customer_portal/auth.py +++ b/server/polar/customer_portal/auth.py @@ -25,3 +25,11 @@ CustomerPortalWrite = Annotated[ AuthSubject[User | Customer], Depends(_CustomerPortalWrite) ] + +_CustomerPortalOAuthAccount = Authenticator( + required_scopes={Scope.web_default, Scope.customer_portal_write}, + allowed_subjects={Customer}, +) +CustomerPortalOAuthAccount = Annotated[ + AuthSubject[Customer], Depends(_CustomerPortalOAuthAccount) +] diff --git a/server/polar/customer_portal/endpoints/__init__.py b/server/polar/customer_portal/endpoints/__init__.py index c689438adf..5bf61b181b 100644 --- a/server/polar/customer_portal/endpoints/__init__.py +++ b/server/polar/customer_portal/endpoints/__init__.py @@ -3,13 +3,15 @@ from .benefit_grant import router as benefit_grant_router from .downloadables import router as downloadables_router from .license_keys import router as license_keys_router +from .oauth_accounts import router as oauth_accounts_router from .order import router as order_router from .subscription import router as subscription_router router = APIRouter(prefix="/customer-portal", tags=["customer_portal"]) router.include_router(benefit_grant_router) -router.include_router(order_router) -router.include_router(subscription_router) router.include_router(downloadables_router) router.include_router(license_keys_router) +router.include_router(oauth_accounts_router) +router.include_router(order_router) +router.include_router(subscription_router) diff --git a/server/polar/customer_portal/endpoints/oauth_accounts.py b/server/polar/customer_portal/endpoints/oauth_accounts.py new file mode 100644 index 0000000000..cb5ca932dc --- /dev/null +++ b/server/polar/customer_portal/endpoints/oauth_accounts.py @@ -0,0 +1,131 @@ +from typing import Any + +import structlog +from fastapi import Depends, Query, Request +from fastapi.responses import RedirectResponse +from httpx_oauth.clients.discord import DiscordOAuth2 +from httpx_oauth.clients.github import GitHubOAuth2 +from httpx_oauth.exceptions import GetProfileError +from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError + +from polar.config import settings +from polar.integrations.github.client import Forbidden +from polar.kit import jwt +from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url +from polar.logging import Logger +from polar.models.customer import CustomerOAuthAccount, CustomerOAuthPlatform +from polar.openapi import APITag +from polar.postgres import AsyncSession, get_db_session +from polar.routing import APIRouter + +from .. import auth + +router = APIRouter(prefix="/oauth-accounts", tags=["oauth-accounts", APITag.private]) + +log: Logger = structlog.get_logger() + + +OAUTH_CLIENTS: dict[CustomerOAuthPlatform, BaseOAuth2[Any]] = { + CustomerOAuthPlatform.github: GitHubOAuth2( + settings.GITHUB_CLIENT_ID, settings.GITHUB_CLIENT_SECRET + ), + CustomerOAuthPlatform.discord: DiscordOAuth2( + settings.DISCORD_CLIENT_ID, + settings.DISCORD_CLIENT_SECRET, + scopes=["identify", "email", "guilds.join"], + ), +} + + +@router.get("/authorize", name="customer_portal.oauth_accounts.authorize") +async def authorize( + request: Request, + return_to: ReturnTo, + auth_subject: auth.CustomerPortalOAuthAccount, + platform: CustomerOAuthPlatform = Query(...), +) -> RedirectResponse: + state = { + "customer_id": str(auth_subject.subject.id), + "platform": platform, + "return_to": return_to, + } + encoded_state = jwt.encode( + data=state, secret=settings.SECRET, type="customer_oauth" + ) + client = OAUTH_CLIENTS[platform] + authorization_url = await client.get_authorization_url( + redirect_uri=str(request.url_for("customer_portal.oauth_accounts.callback")), + state=encoded_state, + ) + return RedirectResponse(authorization_url, 303) + + +@router.get("/callback", name="customer_portal.oauth_accounts.callback") +async def callback( + request: Request, + auth_subject: auth.CustomerPortalOAuthAccount, + state: str, + code: str | None = None, + error: str | None = None, + session: AsyncSession = Depends(get_db_session), +) -> RedirectResponse: + try: + state_data = jwt.decode( + token=state, + secret=settings.SECRET, + type="customer_oauth", + ) + except jwt.DecodeError as e: + raise Forbidden("Invalid state") from e + + if str(auth_subject.subject.id) != state_data["customer_id"]: + raise Forbidden("Invalid state") + + return_to = state_data["return_to"] + platform = CustomerOAuthPlatform(state_data["platform"]) + + if code is None or error is not None: + redirect_url = get_safe_return_url( + add_query_parameters(return_to, error=error or "Failed to authorize.") + ) + return RedirectResponse(redirect_url, 303) + + try: + client = OAUTH_CLIENTS[platform] + oauth2_token_data = await client.get_access_token( + code, str(request.url_for("customer_portal.oauth_accounts.callback")) + ) + except GetAccessTokenError as e: + redirect_url = get_safe_return_url( + add_query_parameters( + return_to, error="Failed to get access token. Please try again later." + ) + ) + log.error("Failed to get access token", error=str(e)) + return RedirectResponse(redirect_url, 303) + + try: + profile = await client.get_profile(oauth2_token_data["access_token"]) + except GetProfileError as e: + redirect_url = get_safe_return_url( + add_query_parameters( + return_to, + error="Failed to get profile information. Please try again later.", + ) + ) + log.error("Failed to get account ID", error=str(e)) + return RedirectResponse(redirect_url, 303) + + oauth_account = CustomerOAuthAccount( + access_token=oauth2_token_data["access_token"], + expires_at=oauth2_token_data["expires_at"], + refresh_token=oauth2_token_data["refresh_token"], + account_id=platform.get_account_id(profile), + account_username=platform.get_account_username(profile), + ) + + customer = auth_subject.subject + customer.set_oauth_account(oauth_account, platform) + session.add(customer) + + return RedirectResponse(state_data["return_to"]) diff --git a/server/polar/integrations/discord/service.py b/server/polar/integrations/discord/service.py index e730ffdf3f..864cadb219 100644 --- a/server/polar/integrations/discord/service.py +++ b/server/polar/integrations/discord/service.py @@ -5,7 +5,6 @@ from polar.exceptions import PolarError from polar.logging import Logger from polar.models import Customer, OAuthAccount, User -from polar.models.customer import CustomerOAuthAccount from polar.models.user import OAuthPlatform from polar.postgres import AsyncSession @@ -126,59 +125,6 @@ async def get_oauth_account( return account -class DiscordCustomerService: - async def create_oauth_account( - self, session: AsyncSession, customer: Customer, oauth2_token_data: OAuth2Token - ) -> CustomerOAuthAccount: - access_token = oauth2_token_data["access_token"] - - client = DiscordClient("Bearer", access_token) - data = await client.get_me() - - account_id = data["id"] - oauth_account = CustomerOAuthAccount( - access_token=access_token, - expires_at=oauth2_token_data["expires_at"], - refresh_token=oauth2_token_data["refresh_token"], - account_id=data["id"], - ) - customer.set_oauth_account(self._get_account_key(account_id), oauth_account) - session.add(customer) - - return oauth_account - - async def get_oauth_account( - self, session: AsyncSession, customer: Customer, account_id: str - ) -> CustomerOAuthAccount: - account_key = self._get_account_key(account_id) - oauth_account = customer.get_oauth_account(account_key) - if oauth_account is None: - raise DiscordCustomerAccountDoesNotExist(customer, account_id) - - if oauth_account.is_expired(): - if oauth_account.refresh_token is None: - raise DiscordCustomerExpiredAccessToken(customer, account_id) - - log.debug( - "Refresh Discord access token", - oauth_account_id=oauth_account.account_id, - customer_id=str(customer.id), - ) - refreshed_token_data = await oauth.user_client.refresh_token( - oauth_account.refresh_token - ) - oauth_account.access_token = refreshed_token_data["access_token"] - oauth_account.expires_at = refreshed_token_data["expires_at"] - oauth_account.refresh_token = refreshed_token_data["refresh_token"] - customer.set_oauth_account(account_key, oauth_account) - session.add(customer) - - return oauth_account - - def _get_account_key(self, account_id: str) -> str: - return f"discord:{oauth.user_client.client_id}:{account_id}" - - class DiscordBotService: async def get_guild(self, id: str) -> DiscordGuild: guild = await bot_client.get_guild(id) @@ -239,5 +185,4 @@ async def is_bot_role_above_role(self, guild_id: str, role_id: str) -> bool: discord_user = DiscordUserService() -discord_customer = DiscordCustomerService() discord_bot = DiscordBotService() diff --git a/server/polar/kit/jwt.py b/server/polar/kit/jwt.py index 9c90a1dfef..0268282081 100644 --- a/server/polar/kit/jwt.py +++ b/server/polar/kit/jwt.py @@ -23,6 +23,7 @@ def create_expiration_dt(seconds: int) -> datetime: "discord_guild_token", "auth", "github_repository_benefit_oauth", + "customer_oauth", ] diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index 5544cc640e..10cb812747 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -1,5 +1,6 @@ import dataclasses import time +from enum import StrEnum from typing import TYPE_CHECKING, Any from uuid import UUID @@ -24,6 +25,28 @@ from .organization import Organization +class CustomerOAuthPlatform(StrEnum): + github = "github" + discord = "discord" + + def get_account_key(self, account_id: str) -> str: + return f"{self.value}:{account_id}" + + def get_account_id(self, data: dict[str, Any]) -> str: + if self == CustomerOAuthPlatform.github: + return data["id"] + if self == CustomerOAuthPlatform.discord: + return data["id"] + raise NotImplementedError() + + def get_account_username(self, data: dict[str, Any]) -> str: + if self == CustomerOAuthPlatform.github: + return data["login"] + if self == CustomerOAuthPlatform.discord: + return data["username"] + raise NotImplementedError() + + @dataclasses.dataclass class CustomerOAuthAccount: access_token: str @@ -77,17 +100,25 @@ class Customer(MetadataMixin, RecordModel): def organization(cls) -> Mapped["Organization"]: return relationship("Organization", lazy="raise") - def get_oauth_account(self, account_key: str) -> CustomerOAuthAccount | None: - oauth_account_data = self._oauth_accounts.get(account_key) + def get_oauth_account( + self, account_id: str, platform: CustomerOAuthPlatform + ) -> CustomerOAuthAccount | None: + oauth_account_data = self._oauth_accounts.get( + platform.get_account_key(account_id) + ) if oauth_account_data is None: return None return CustomerOAuthAccount(**oauth_account_data) def set_oauth_account( - self, account_key: str, oauth_account: CustomerOAuthAccount + self, oauth_account: CustomerOAuthAccount, platform: CustomerOAuthPlatform ) -> None: + account_key = platform.get_account_key(oauth_account.account_id) self._oauth_accounts[account_key] = dataclasses.asdict(oauth_account) - def remove_oauth_account(self, account_key: str) -> None: + def remove_oauth_account( + self, account_id: str, platform: CustomerOAuthPlatform + ) -> None: + account_key = platform.get_account_key(account_id) self._oauth_accounts.pop(account_key, None) diff --git a/server/pyproject.toml b/server/pyproject.toml index 7af08690f8..fd9af05887 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "resend>=2.4.0", "python-multipart>=0.0.12", "safe-redirect-url>=0.1.1", - "httpx-oauth>=0.15.1", + "httpx-oauth>=0.16.0", "httpx>=0.23.0", "pydantic-settings>=2.5.2", "email-validator>=2.1.0.post1", diff --git a/server/uv.lock b/server/uv.lock index f1dfc8fd2b..0448287a9b 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -613,14 +613,14 @@ wheels = [ [[package]] name = "httpx-oauth" -version = "0.15.1" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6d/08/b2e0f360e84994ec5c67473cef7198d22ccb2215e5068c0a21182d22baaa/httpx_oauth-0.15.1.tar.gz", hash = "sha256:4094cf0938fc7252b5f5dfd62cd1ab5aee2fcb6734e621942ee17d1af4806b74", size = 41302 } +sdist = { url = "https://files.pythonhosted.org/packages/1d/af/e4476044977493251c8d33e0400ffdf1bf09bd91b3fc71f84ba92b2bfdd2/httpx_oauth-0.16.0.tar.gz", hash = "sha256:5ce4696e4c9572711fb558808f7436afd7712db825c56b0f4f430e16c649f136", size = 41728 } wheels = [ - { url = "https://files.pythonhosted.org/packages/32/ac/7ed132cb22ed9ca1be2e837834cbe1786d79375b9c6a6f6a595aa1629ec3/httpx_oauth-0.15.1-py3-none-any.whl", hash = "sha256:89b45f250e93e42bbe9631adf349cab0e3d3ced958c07e06651735198d1bdf00", size = 37304 }, + { url = "https://files.pythonhosted.org/packages/0a/41/1a082687954efe61a77480c74ead8937faf6f3fcb4b7b24cbd56492769ff/httpx_oauth-0.16.0-py3-none-any.whl", hash = "sha256:4394a5e60cb66fd6e14031d41be8e4060580c553422a641b624ade0f33d0df04", size = 38045 }, ] [[package]] @@ -1279,7 +1279,7 @@ requires-dist = [ { name = "githubkit", specifier = "==0.11.14" }, { name = "greenlet", specifier = ">=3.1.1" }, { name = "httpx", specifier = ">=0.23.0" }, - { name = "httpx-oauth", specifier = ">=0.15.1" }, + { name = "httpx-oauth", specifier = ">=0.16.0" }, { name = "ipinfo-db", specifier = ">=0.0.4" }, { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "jinja2", specifier = ">=3.1.2" }, From 1330499f3d8c521fbb2bc098d7cf43bf54caaab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 6 Dec 2024 14:25:19 +0100 Subject: [PATCH 15/47] server/customer_session: implement customer authentication mechanism --- server/polar/auth/dependencies.py | 22 ++++++++- server/polar/auth/models.py | 1 + server/polar/config.py | 3 ++ server/polar/customer_session/__init__.py | 0 server/polar/customer_session/dependencies.py | 23 +++++++++ server/polar/customer_session/service.py | 47 +++++++++++++++++++ server/polar/customer_session/tasks.py | 9 ++++ server/polar/models/__init__.py | 2 + server/polar/models/customer_session.py | 31 ++++++++++++ server/polar/tasks.py | 2 + 10 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 server/polar/customer_session/__init__.py create mode 100644 server/polar/customer_session/dependencies.py create mode 100644 server/polar/customer_session/service.py create mode 100644 server/polar/customer_session/tasks.py create mode 100644 server/polar/models/customer_session.py diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index e59a9ce2e1..53c5dc38d9 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -5,8 +5,9 @@ from makefun import with_signature from polar.auth.scope import RESERVED_SCOPES, Scope +from polar.customer_session.dependencies import get_optional_customer_session_token from polar.exceptions import NotPermitted, Unauthorized -from polar.models import OAuth2Token, PersonalAccessToken, UserSession +from polar.models import CustomerSession, OAuth2Token, PersonalAccessToken, UserSession from polar.oauth2.dependencies import get_optional_token from polar.oauth2.exceptions import InsufficientScopeError, InvalidTokenError from polar.personal_access_token.dependencies import get_optional_personal_access_token @@ -37,6 +38,9 @@ async def get_auth_subject( personal_access_token_credentials: tuple[ PersonalAccessToken | None, bool ] = Depends(get_optional_personal_access_token), + customer_session_credentials: tuple[CustomerSession | None, bool] = Depends( + get_optional_customer_session_token + ), ) -> AuthSubject[Subject]: # Web session if user_session is not None: @@ -54,6 +58,7 @@ async def get_auth_subject( personal_access_token, personal_access_token_authorization_set = ( personal_access_token_credentials ) + customer_session, customer_session_authorization_set = customer_session_credentials if oauth2_token: return AuthSubject( @@ -67,7 +72,20 @@ async def get_auth_subject( AuthMethod.PERSONAL_ACCESS_TOKEN, ) - if oauth2_authorization_set or personal_access_token_authorization_set: + if customer_session: + return AuthSubject( + customer_session.customer, + {Scope.customer_portal_write}, + AuthMethod.CUSTOMER_SESSION_TOKEN, + ) + + if any( + ( + oauth2_authorization_set, + personal_access_token_authorization_set, + customer_session_authorization_set, + ) + ): raise InvalidTokenError() return AuthSubject(Anonymous(), set(), AuthMethod.NONE) diff --git a/server/polar/auth/models.py b/server/polar/auth/models.py index e27cd6a52c..707b3a5a55 100644 --- a/server/polar/auth/models.py +++ b/server/polar/auth/models.py @@ -18,6 +18,7 @@ class AuthMethod(Enum): COOKIE = auto() PERSONAL_ACCESS_TOKEN = auto() OAUTH2_ACCESS_TOKEN = auto() + CUSTOMER_SESSION_TOKEN = auto() S = TypeVar("S", bound=Subject, covariant=True) diff --git a/server/polar/config.py b/server/polar/config.py index c6f5f5aa1f..1f8d70eb45 100644 --- a/server/polar/config.py +++ b/server/polar/config.py @@ -62,6 +62,9 @@ class Settings(BaseSettings): USER_SESSION_COOKIE_KEY: str = "polar_session" USER_SESSION_COOKIE_DOMAIN: str = "127.0.0.1" + # Customer session + CUSTOMER_SESSION_TTL: timedelta = timedelta(hours=1) + # Magic link MAGIC_LINK_TTL_SECONDS: int = 60 * 30 # 30 minutes diff --git a/server/polar/customer_session/__init__.py b/server/polar/customer_session/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/polar/customer_session/dependencies.py b/server/polar/customer_session/dependencies.py new file mode 100644 index 0000000000..d45cd4f225 --- /dev/null +++ b/server/polar/customer_session/dependencies.py @@ -0,0 +1,23 @@ +from fastapi import Depends +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from polar.models import CustomerSession +from polar.postgres import AsyncSession, get_db_session + +from .service import customer_session as customer_session_service + +auth_header_scheme = HTTPBearer(scheme_name="customer_session", auto_error=False) + + +async def get_optional_customer_session_token( + auth_header: HTTPAuthorizationCredentials | None = Depends(auth_header_scheme), + session: AsyncSession = Depends(get_db_session), +) -> tuple[CustomerSession | None, bool]: + if auth_header is None: + return None, False + + token = await customer_session_service.get_by_token( + session, auth_header.credentials + ) + + return token, True diff --git a/server/polar/customer_session/service.py b/server/polar/customer_session/service.py new file mode 100644 index 0000000000..6f4939db82 --- /dev/null +++ b/server/polar/customer_session/service.py @@ -0,0 +1,47 @@ +from sqlalchemy import delete, select + +from polar.config import settings +from polar.kit.crypto import generate_token_hash_pair, get_token_hash +from polar.kit.services import ResourceServiceReader +from polar.kit.utils import utc_now +from polar.models import Customer, CustomerSession +from polar.postgres import AsyncSession + +CUSTOMER_SESSION_TOKEN_PREFIX = "polar_cst_" + + +class CustomerSessionService(ResourceServiceReader[CustomerSession]): + async def create_customer_session( + self, session: AsyncSession, customer: Customer + ) -> tuple[str, CustomerSession]: + token, token_hash = generate_token_hash_pair( + secret=settings.SECRET, prefix=CUSTOMER_SESSION_TOKEN_PREFIX + ) + customer_session = CustomerSession(token=token_hash, customer=customer) + session.add(customer_session) + await session.flush() + + return token, customer_session + + async def get_by_token( + self, session: AsyncSession, token: str, *, expired: bool = False + ) -> CustomerSession | None: + token_hash = get_token_hash(token, secret=settings.SECRET) + statement = select(CustomerSession).where( + CustomerSession.token == token_hash, + CustomerSession.deleted_at.is_(None), + ) + if not expired: + statement = statement.where(CustomerSession.expires_at > utc_now()) + + result = await session.execute(statement) + return result.unique().scalar_one_or_none() + + async def delete_expired(self, session: AsyncSession) -> None: + statement = delete(CustomerSession).where( + CustomerSession.expires_at < utc_now() + ) + await session.execute(statement) + + +customer_session = CustomerSessionService(CustomerSession) diff --git a/server/polar/customer_session/tasks.py b/server/polar/customer_session/tasks.py new file mode 100644 index 0000000000..31f140bf71 --- /dev/null +++ b/server/polar/customer_session/tasks.py @@ -0,0 +1,9 @@ +from polar.worker import AsyncSessionMaker, CronTrigger, JobContext, task + +from .service import customer_session as customer_session_service + + +@task("customer_session.delete_expired", cron_trigger=CronTrigger(hour=0, minute=0)) +async def customer_session_delete_expired(ctx: JobContext) -> None: + async with AsyncSessionMaker(ctx) as session: + await customer_session_service.delete_expired(session) diff --git a/server/polar/models/__init__.py b/server/polar/models/__init__.py index f0bac160b8..afe43110d4 100644 --- a/server/polar/models/__init__.py +++ b/server/polar/models/__init__.py @@ -8,6 +8,7 @@ from .checkout_link import CheckoutLink from .custom_field import CustomField from .customer import Customer +from .customer_session import CustomerSession from .discount import Discount from .discount_product import DiscountProduct from .discount_redemption import DiscountRedemption @@ -62,6 +63,7 @@ "Checkout", "CheckoutLink", "Customer", + "CustomerSession", "CustomField", "Discount", "DiscountProduct", diff --git a/server/polar/models/customer_session.py b/server/polar/models/customer_session.py new file mode 100644 index 0000000000..6794ddfdf5 --- /dev/null +++ b/server/polar/models/customer_session.py @@ -0,0 +1,31 @@ +from datetime import datetime +from uuid import UUID + +from sqlalchemy import CHAR, TIMESTAMP, ForeignKey, Uuid +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship + +from polar.config import settings +from polar.kit.db.models.base import RecordModel +from polar.kit.utils import utc_now +from polar.models.customer import Customer + + +def get_expires_at() -> datetime: + return utc_now() + settings.CUSTOMER_SESSION_TTL + + +class CustomerSession(RecordModel): + __tablename__ = "customer_sessions" + + token: Mapped[str] = mapped_column(CHAR(64), unique=True, nullable=False) + expires_at: Mapped[datetime] = mapped_column( + TIMESTAMP(timezone=True), nullable=False, index=True, default=get_expires_at + ) + + customer_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False + ) + + @declared_attr + def customer(cls) -> Mapped[Customer]: + return relationship(Customer, lazy="joined") diff --git a/server/polar/tasks.py b/server/polar/tasks.py index 734eee122d..bb1ca7db3b 100644 --- a/server/polar/tasks.py +++ b/server/polar/tasks.py @@ -2,6 +2,7 @@ from polar.auth import tasks as auth from polar.benefit import tasks as benefit from polar.checkout import tasks as checkout +from polar.customer_session import tasks as customer_session from polar.eventstream import tasks as eventstream from polar.integrations.github import tasks as github from polar.integrations.loops import tasks as loops @@ -21,6 +22,7 @@ "auth", "benefit", "checkout", + "customer_session", "eventstream", "github", "loops", From d2322ac989d75b4e36c815c6d6638eda2520f6cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 6 Dec 2024 14:42:00 +0100 Subject: [PATCH 16/47] server/checkout: generate a customer session after confirmed checkotu --- server/polar/checkout/endpoints.py | 3 ++- server/polar/checkout/schemas.py | 12 ++++++++++++ server/polar/checkout/service.py | 15 +++++++++++---- server/polar/models/checkout.py | 8 ++++++++ server/tests/checkout/test_endpoints.py | 3 +++ server/tests/checkout/test_service.py | 8 ++++++++ 6 files changed, 44 insertions(+), 5 deletions(-) diff --git a/server/polar/checkout/endpoints.py b/server/polar/checkout/endpoints.py index dc3982eaf5..e95fd72774 100644 --- a/server/polar/checkout/endpoints.py +++ b/server/polar/checkout/endpoints.py @@ -25,6 +25,7 @@ CheckoutCreate, CheckoutCreatePublic, CheckoutPublic, + CheckoutPublicConfirmed, CheckoutUpdate, CheckoutUpdatePublic, ) @@ -212,7 +213,7 @@ async def client_update( @router.post( "/client/{client_secret}/confirm", - response_model=CheckoutPublic, + response_model=CheckoutPublicConfirmed, summary="Confirm Checkout Session from Client", responses={ 200: {"description": "Checkout session confirmed."}, diff --git a/server/polar/checkout/schemas.py b/server/polar/checkout/schemas.py index b57d337470..a6896e96c1 100644 --- a/server/polar/checkout/schemas.py +++ b/server/polar/checkout/schemas.py @@ -416,3 +416,15 @@ class CheckoutPublic(CheckoutBase): discount: CheckoutDiscount | None organization: Organization attached_custom_fields: list[AttachedCustomField] + + +class CheckoutPublicConfirmed(CheckoutPublic): + """ + Checkout session data retrieved using the client secret after confirmation. + + It contains a customer session token to retrieve order information + right after the checkout. + """ + + status: Literal[CheckoutStatus.confirmed] + customer_session_token: str diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index b8993e91a1..eef172c7e7 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -23,6 +23,7 @@ from polar.config import settings from polar.custom_field.data import validate_custom_field_data from polar.customer.service import customer as customer_service +from polar.customer_session.service import customer_session as customer_session_service from polar.discount.service import DiscountNotRedeemableError from polar.discount.service import discount as discount_service from polar.enums import PaymentProcessor @@ -54,10 +55,7 @@ UserOrganization, ) from polar.models.checkout import CheckoutStatus -from polar.models.product_price import ( - ProductPriceAmountType, - ProductPriceFree, -) +from polar.models.product_price import ProductPriceAmountType, ProductPriceFree from polar.models.webhook_endpoint import WebhookEventType from polar.organization.service import organization as organization_service from polar.postgres import AsyncSession @@ -815,6 +813,15 @@ async def _confirm_inner( await self._after_checkout_updated(session, checkout) + assert checkout.customer is not None + ( + customer_session_token, + _, + ) = await customer_session_service.create_customer_session( + session, checkout.customer + ) + checkout.customer_session_token = customer_session_token + return checkout async def handle_stripe_success( diff --git a/server/polar/models/checkout.py b/server/polar/models/checkout.py index 4ef3c42621..b682152117 100644 --- a/server/polar/models/checkout.py +++ b/server/polar/models/checkout.py @@ -217,6 +217,14 @@ def is_payment_form_required(self) -> bool: def url(self) -> str: return settings.generate_frontend_url(f"/checkout/{self.client_secret}") + @property + def customer_session_token(self) -> str | None: + return getattr(self, "_customer_session_token", None) + + @customer_session_token.setter + def customer_session_token(self, value: str) -> None: + self._customer_session_token = value + attached_custom_fields: AssociationProxy[Sequence["AttachedCustomFieldMixin"]] = ( association_proxy("product", "attached_custom_fields") ) diff --git a/server/tests/checkout/test_endpoints.py b/server/tests/checkout/test_endpoints.py index d65dde8fc4..81f4d896ad 100644 --- a/server/tests/checkout/test_endpoints.py +++ b/server/tests/checkout/test_endpoints.py @@ -316,3 +316,6 @@ async def test_valid( ) assert response.status_code == 200 + + json = response.json() + assert "customer_session_token" in json diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index 75743e2571..40271c3c1f 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -31,6 +31,7 @@ PaymentRequired, ) from polar.checkout.service import checkout as checkout_service +from polar.customer_session.service import customer_session as customer_session_service from polar.discount.service import discount as discount_service from polar.enums import PaymentProcessor from polar.exceptions import PolarRequestValidationError @@ -1895,6 +1896,13 @@ async def test_valid_stripe( **expected_tax_metadata, } + assert checkout.customer_session_token is not None + customer_session = await customer_session_service.get_by_token( + session, checkout.customer_session_token + ) + assert customer_session is not None + assert customer_session.customer == checkout.customer + @pytest.mark.parametrize( "customer_billing_address,expected_tax_metadata", [ From 8a92227f3f5cfdb38246f124ac75b437d0b45677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 9 Dec 2024 11:38:16 +0100 Subject: [PATCH 17/47] server/auth: dynamically add the customer session dependency so it's enabled only when applicable --- server/polar/auth/dependencies.py | 82 ++++++++++++++++++++++++++----- server/tests/fixtures/base.py | 6 +-- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index 53c5dc38d9..1d9952d162 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -1,5 +1,6 @@ +from collections.abc import Awaitable, Callable from inspect import Parameter, Signature -from typing import Annotated +from typing import Annotated, Any from fastapi import Depends, Request, Security from makefun import with_signature @@ -7,7 +8,13 @@ from polar.auth.scope import RESERVED_SCOPES, Scope from polar.customer_session.dependencies import get_optional_customer_session_token from polar.exceptions import NotPermitted, Unauthorized -from polar.models import CustomerSession, OAuth2Token, PersonalAccessToken, UserSession +from polar.models import ( + Customer, + CustomerSession, + OAuth2Token, + PersonalAccessToken, + UserSession, +) from polar.oauth2.dependencies import get_optional_token from polar.oauth2.exceptions import InsufficientScopeError, InvalidTokenError from polar.personal_access_token.dependencies import get_optional_personal_access_token @@ -32,15 +39,14 @@ async def get_user_session( return await auth_service.authenticate(session, request) -async def get_auth_subject( - user_session: UserSession | None = Depends(get_user_session), - oauth2_credentials: tuple[OAuth2Token | None, bool] = Depends(get_optional_token), - personal_access_token_credentials: tuple[ - PersonalAccessToken | None, bool - ] = Depends(get_optional_personal_access_token), - customer_session_credentials: tuple[CustomerSession | None, bool] = Depends( - get_optional_customer_session_token +async def _get_auth_subject( + user_session: UserSession | None = None, + oauth2_credentials: tuple[OAuth2Token | None, bool] = (None, False), + personal_access_token_credentials: tuple[PersonalAccessToken | None, bool] = ( + None, + False, ), + customer_session_credentials: tuple[CustomerSession | None, bool] = (None, False), ) -> AuthSubject[Subject]: # Web session if user_session is not None: @@ -91,11 +97,59 @@ async def get_auth_subject( return AuthSubject(Anonymous(), set(), AuthMethod.NONE) +_auth_subject_factory_cache: dict[ + frozenset[SubjectType], Callable[..., Awaitable[AuthSubject[Subject]]] +] = {} + + +def _get_auth_subject_factory( + allowed_subjects: frozenset[SubjectType], +) -> Callable[..., Awaitable[AuthSubject[Subject]]]: + if allowed_subjects in _auth_subject_factory_cache: + return _auth_subject_factory_cache[allowed_subjects] + + parameters: list[Parameter] = [ + Parameter( + name="user_session", + kind=Parameter.KEYWORD_ONLY, + default=Depends(get_user_session), + ), + Parameter( + name="oauth2_credentials", + kind=Parameter.KEYWORD_ONLY, + default=Depends(get_optional_token), + ), + Parameter( + name="personal_access_token_credentials", + kind=Parameter.KEYWORD_ONLY, + default=Depends(get_optional_personal_access_token), + ), + ] + if Customer in allowed_subjects: + parameters.append( + Parameter( + name="customer_session_credentials", + kind=Parameter.KEYWORD_ONLY, + default=Depends(get_optional_customer_session_token), + ) + ) + + signature = Signature(parameters) + + @with_signature(signature) + async def get_auth_subject(**kwargs: Any) -> AuthSubject[Subject]: + return await _get_auth_subject(**kwargs) + + _auth_subject_factory_cache[allowed_subjects] = get_auth_subject + + return get_auth_subject + + class _Authenticator: def __init__( self, *, - allowed_subjects: set[SubjectType], + allowed_subjects: frozenset[SubjectType], required_scopes: set[Scope] | None = None, ) -> None: self.allowed_subjects = allowed_subjects @@ -149,13 +203,15 @@ def Authenticator( By doing so, we can dynamically inject the required scopes into FastAPI dependency, so they are properrly detected by the OpenAPI generator. """ + allowed_subjects_frozen = frozenset(allowed_subjects) + parameters: list[Parameter] = [ Parameter(name="self", kind=Parameter.POSITIONAL_OR_KEYWORD), Parameter( name="auth_subject", kind=Parameter.POSITIONAL_OR_KEYWORD, default=Security( - get_auth_subject, + _get_auth_subject_factory(allowed_subjects_frozen), scopes=sorted( [ s.value @@ -176,7 +232,7 @@ async def __call__( return await super().__call__(auth_subject) return _AuthenticatorSignature( - allowed_subjects=allowed_subjects, required_scopes=required_scopes + allowed_subjects=allowed_subjects_frozen, required_scopes=required_scopes ) diff --git a/server/tests/fixtures/base.py b/server/tests/fixtures/base.py index 3def1955f9..c05ae80af3 100644 --- a/server/tests/fixtures/base.py +++ b/server/tests/fixtures/base.py @@ -6,7 +6,7 @@ from httpx import AsyncClient from polar.app import app -from polar.auth.dependencies import get_auth_subject +from polar.auth.dependencies import _auth_subject_factory_cache from polar.auth.models import AuthSubject, Subject from polar.checkout.ip_geolocation import _get_client_dependency from polar.postgres import AsyncSession, get_db_session @@ -22,8 +22,9 @@ async def client( ) -> AsyncGenerator[AsyncClient, None]: app.dependency_overrides[get_db_session] = lambda: session app.dependency_overrides[get_redis] = lambda: redis - app.dependency_overrides[get_auth_subject] = lambda: auth_subject app.dependency_overrides[_get_client_dependency] = lambda: None + for auth_subject_getter in _auth_subject_factory_cache.values(): + app.dependency_overrides[auth_subject_getter] = lambda: auth_subject request_hooks = [] @@ -52,4 +53,3 @@ async def expunge_hook(request: Any) -> None: yield client app.dependency_overrides.pop(get_db_session) - app.dependency_overrides.pop(get_auth_subject) From 56cf8f6b477fdf2af65d5e865689c629826f3646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 9 Dec 2024 13:15:55 +0100 Subject: [PATCH 18/47] server/customer_portal: add API to update grants --- .../endpoints/benefit_grant.py | 36 +++++++++- .../customer_portal/schemas/benefit_grant.py | 66 +++++++++++++++-- .../customer_portal/service/benefit_grant.py | 72 ++++++++++++++++++- 3 files changed, 166 insertions(+), 8 deletions(-) diff --git a/server/polar/customer_portal/endpoints/benefit_grant.py b/server/polar/customer_portal/endpoints/benefit_grant.py index 12d96faccb..2cf22e42a2 100644 --- a/server/polar/customer_portal/endpoints/benefit_grant.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -3,7 +3,7 @@ from fastapi import Depends, Path, Query from pydantic import UUID4 -from polar.exceptions import ResourceNotFound +from polar.exceptions import NotPermitted, ResourceNotFound from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import ListResource, PaginationParamsQuery from polar.kit.schemas import MultipleQueryFilter @@ -17,7 +17,7 @@ from .. import auth from ..schemas.benefit_grant import BenefitGrant as BenefitGrantSchema -from ..schemas.benefit_grant import BenefitGrantAdapter +from ..schemas.benefit_grant import BenefitGrantAdapter, BenefitGrantUpdate from ..service.benefit_grant import CustomerBenefitGrantSortProperty from ..service.benefit_grant import ( customer_benefit_grant as customer_benefit_grant_service, @@ -100,3 +100,35 @@ async def get( raise ResourceNotFound() return benefit_grant + + +@router.get( + "/{id}", + summary="Update Benefit Grant", + response_model=BenefitGrantSchema, + responses={ + 200: {"description": "Benefit grant updated."}, + 403: { + "description": "The benefit grant is revoked and cannot be updated.", + "model": NotPermitted.schema(), + }, + 404: BenefitGrantNotFound, + }, +) +async def update( + id: BenefitGrantID, + benefit_grant_update: BenefitGrantUpdate, + auth_subject: auth.CustomerPortalWrite, + session: AsyncSession = Depends(get_db_session), +) -> BenefitGrant: + """Update a benefit grant for the authenticated customer or user.""" + benefit_grant = await customer_benefit_grant_service.get_by_id( + session, auth_subject, id + ) + + if benefit_grant is None: + raise ResourceNotFound() + + return await customer_benefit_grant_service.update( + session, benefit_grant, benefit_grant_update + ) diff --git a/server/polar/customer_portal/schemas/benefit_grant.py b/server/polar/customer_portal/schemas/benefit_grant.py index 4418c8f3b5..c6d5985479 100644 --- a/server/polar/customer_portal/schemas/benefit_grant.py +++ b/server/polar/customer_portal/schemas/benefit_grant.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Annotated +from typing import Annotated, Literal, TypedDict -from pydantic import UUID4, TypeAdapter +from pydantic import UUID4, Discriminator, TypeAdapter from polar.benefit.schemas import ( BenefitAdsSubscriber, @@ -12,7 +12,8 @@ BenefitLicenseKeysSubscriber, BenefitSubscriber, ) -from polar.kit.schemas import IDSchema, MergeJSONSchema, TimestampedSchema +from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema +from polar.models.benefit import BenefitType from polar.models.benefit_grant import ( BenefitGrantAdsProperties, BenefitGrantCustomProperties, @@ -21,6 +22,7 @@ BenefitGrantGitHubRepositoryProperties, BenefitGrantLicenseKeysProperties, ) +from polar.models.customer import CustomerOAuthPlatform class BenefitGrantBase(IDSchema, TimestampedSchema): @@ -65,7 +67,7 @@ class BenefitGrantAds(BenefitGrantBase): properties: BenefitGrantAdsProperties -class BenefitGrantCustomer(BenefitGrantBase): +class BenefitGrantCustom(BenefitGrantBase): benefit: BenefitCustomSubscriber properties: BenefitGrantCustomProperties @@ -76,7 +78,61 @@ class BenefitGrantCustomer(BenefitGrantBase): | BenefitGrantDownloadables | BenefitGrantLicenseKeys | BenefitGrantAds - | BenefitGrantCustomer, + | BenefitGrantCustom, MergeJSONSchema({"title": "BenefitGrant"}), ] BenefitGrantAdapter: TypeAdapter[BenefitGrant] = TypeAdapter(BenefitGrant) + + +class BenefitGrantUpdateBase(Schema): + benefit_type: BenefitType + + +class BenefitGrantDiscordPropertiesUpdate(TypedDict): + account_id: str + + +class BenefitGrantDiscordUpdate(BenefitGrantUpdateBase): + properties: BenefitGrantDiscordPropertiesUpdate + + def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.discord]: + return CustomerOAuthPlatform.discord + + +class BenefitGrantGitHubRepositoryPropertiesUpdate(TypedDict): + account_id: str + + +class BenefitGrantGitHubRepositoryUpdate(BenefitGrantUpdateBase): + properties: BenefitGrantGitHubRepositoryPropertiesUpdate + + def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.github]: + return CustomerOAuthPlatform.github + + +class BenefitGrantDownloadablesUpdate(BenefitGrantUpdateBase): + pass + + +class BenefitGrantLicenseKeysUpdate(BenefitGrantUpdateBase): + pass + + +class BenefitGrantAdsUpdate(BenefitGrantUpdateBase): + pass + + +class BenefitGrantCustomUpdate(BenefitGrantUpdateBase): + pass + + +BenefitGrantUpdate = Annotated[ + BenefitGrantDiscordUpdate + | BenefitGrantGitHubRepositoryUpdate + | BenefitGrantDownloadablesUpdate + | BenefitGrantLicenseKeysUpdate + | BenefitGrantAdsUpdate + | BenefitGrantCustomUpdate, + MergeJSONSchema({"title": "BenefitGrantUpdate"}), + Discriminator("benefit_type"), +] diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py index 9edd3c59bb..ddd03da138 100644 --- a/server/polar/customer_portal/service/benefit_grant.py +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -1,12 +1,14 @@ import uuid from collections.abc import Sequence from enum import StrEnum -from typing import Any +from typing import Any, cast from sqlalchemy import Select, UnaryExpression, asc, desc, select from sqlalchemy.orm import contains_eager from polar.auth.models import AuthSubject, is_customer, is_user +from polar.customer.service import customer as customer_service +from polar.exceptions import NotPermitted, PolarRequestValidationError from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader @@ -20,6 +22,13 @@ UserCustomer, ) from polar.models.benefit import BenefitType +from polar.worker import enqueue_job + +from ..schemas.benefit_grant import ( + BenefitGrantDiscordUpdate, + BenefitGrantGitHubRepositoryUpdate, + BenefitGrantUpdate, +) class CustomerBenefitGrantSortProperty(StrEnum): @@ -89,6 +98,67 @@ async def get_by_id( result = await session.execute(statement) return result.scalar_one_or_none() + async def update( + self, + session: AsyncSession, + benefit_grant: BenefitGrant, + benefit_grant_update: BenefitGrantUpdate, + ) -> BenefitGrant: + if benefit_grant.is_revoked: + raise NotPermitted("Cannot update a revoked benefit grant.") + + if benefit_grant_update.benefit_type != benefit_grant.benefit.type: + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "benefit_type"), + "msg": "Benefit type must match the existing granted benefit type.", + "input": benefit_grant_update.benefit_type, + } + ] + ) + + if isinstance(benefit_grant_update, BenefitGrantDiscordUpdate) or isinstance( + benefit_grant_update, BenefitGrantGitHubRepositoryUpdate + ): + account_id = benefit_grant_update.properties["account_id"] + platform = benefit_grant_update.get_oauth_platform() + + customer = await customer_service.get(session, benefit_grant.customer_id) + assert customer is not None + + oauth_account = customer.get_oauth_account(account_id, platform) + if oauth_account is None: + raise PolarRequestValidationError( + [ + { + "type": "value_error", + "loc": ("body", "properties", "account_id"), + "msg": "OAuth account does not exist.", + "input": account_id, + } + ] + ) + + benefit_grant.properties = cast( + Any, + { + **benefit_grant.properties, + **benefit_grant_update.properties, + }, + ) + + enqueue_job("benefit.update", benefit_grant.id) + + for attr, value in benefit_grant_update.model_dump( + exclude_unset=True, exclude={"properties", "benefit_type"} + ).items(): + setattr(benefit_grant, attr, value) + + session.add(benefit_grant) + return benefit_grant + def _get_readable_benefit_grant_statement( self, auth_subject: AuthSubject[User | Customer] ) -> Select[tuple[BenefitGrant]]: From 208d0f41f06448558a653b35a6ce79bfbc4fb29f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 9 Dec 2024 13:53:51 +0100 Subject: [PATCH 19/47] server: remove user ads endpoints --- server/polar/auth/scope.py | 7 - .../customer_portal/schemas/benefit_grant.py | 10 +- server/polar/user/auth.py | 23 -- server/polar/user/endpoints/__init__.py | 2 - server/polar/user/endpoints/advertisement.py | 201 ----------- server/polar/user/service/advertisement.py | 186 ---------- server/tests/advertisements/__init__.py | 0 server/tests/advertisements/conftest.py | 30 -- server/tests/advertisements/test_endpoints.py | 100 ------ server/tests/advertisements/test_service.py | 108 ------ server/tests/user/service/__init__.py | 0 .../tests/user/service/test_advertisement.py | 326 ------------------ 12 files changed, 6 insertions(+), 987 deletions(-) delete mode 100644 server/polar/user/auth.py delete mode 100644 server/polar/user/endpoints/advertisement.py delete mode 100644 server/polar/user/service/advertisement.py delete mode 100644 server/tests/advertisements/__init__.py delete mode 100644 server/tests/advertisements/conftest.py delete mode 100644 server/tests/advertisements/test_endpoints.py delete mode 100644 server/tests/advertisements/test_service.py delete mode 100644 server/tests/user/service/__init__.py delete mode 100644 server/tests/user/service/test_advertisement.py diff --git a/server/polar/auth/scope.py b/server/polar/auth/scope.py index 92e90fa9ae..8dd438a6dc 100644 --- a/server/polar/auth/scope.py +++ b/server/polar/auth/scope.py @@ -62,9 +62,6 @@ class Scope(StrEnum): issues_read = "issues:read" issues_write = "issues:write" - user_advertisement_campaigns_read = "user:advertisement_campaigns:read" - user_advertisement_campaigns_write = "user:advertisement_campaigns:write" - customer_portal_read = "customer_portal:read" customer_portal_write = "customer_portal:write" @@ -116,10 +113,6 @@ def __get_pydantic_json_schema__( Scope.webhooks_write: "Create or modify webhooks", Scope.customer_portal_read: "Read your orders, subscriptions and benefits", Scope.customer_portal_write: "Create or modify your orders, subscriptions and benefits", - Scope.user_advertisement_campaigns_read: "Read your advertisement campaigns", - Scope.user_advertisement_campaigns_write: ( - "Create or modify your advertisement campaigns" - ), } diff --git a/server/polar/customer_portal/schemas/benefit_grant.py b/server/polar/customer_portal/schemas/benefit_grant.py index c6d5985479..48771f7b8f 100644 --- a/server/polar/customer_portal/schemas/benefit_grant.py +++ b/server/polar/customer_portal/schemas/benefit_grant.py @@ -93,6 +93,7 @@ class BenefitGrantDiscordPropertiesUpdate(TypedDict): class BenefitGrantDiscordUpdate(BenefitGrantUpdateBase): + benefit_type: Literal[BenefitType.discord] properties: BenefitGrantDiscordPropertiesUpdate def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.discord]: @@ -104,6 +105,7 @@ class BenefitGrantGitHubRepositoryPropertiesUpdate(TypedDict): class BenefitGrantGitHubRepositoryUpdate(BenefitGrantUpdateBase): + benefit_type: Literal[BenefitType.github_repository] properties: BenefitGrantGitHubRepositoryPropertiesUpdate def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.github]: @@ -111,19 +113,19 @@ def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.github]: class BenefitGrantDownloadablesUpdate(BenefitGrantUpdateBase): - pass + benefit_type: Literal[BenefitType.downloadables] class BenefitGrantLicenseKeysUpdate(BenefitGrantUpdateBase): - pass + benefit_type: Literal[BenefitType.license_keys] class BenefitGrantAdsUpdate(BenefitGrantUpdateBase): - pass + benefit_type: Literal[BenefitType.ads] class BenefitGrantCustomUpdate(BenefitGrantUpdateBase): - pass + benefit_type: Literal[BenefitType.custom] BenefitGrantUpdate = Annotated[ diff --git a/server/polar/user/auth.py b/server/polar/user/auth.py deleted file mode 100644 index 3deb66cfdf..0000000000 --- a/server/polar/user/auth.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Annotated - -from fastapi import Depends - -from polar.auth.dependencies import Authenticator -from polar.auth.models import AuthSubject, User -from polar.auth.scope import Scope - -_UserAdvertisementCampaignsRead = Authenticator( - required_scopes={Scope.web_default, Scope.user_advertisement_campaigns_read}, - allowed_subjects={User}, -) -UserAdvertisementCampaignsRead = Annotated[ - AuthSubject[User], Depends(_UserAdvertisementCampaignsRead) -] - -_UserAdvertisementCampaignsWrite = Authenticator( - required_scopes={Scope.web_default, Scope.user_advertisement_campaigns_write}, - allowed_subjects={User}, -) -UserAdvertisementCampaignsWrite = Annotated[ - AuthSubject[User], Depends(_UserAdvertisementCampaignsWrite) -] diff --git a/server/polar/user/endpoints/__init__.py b/server/polar/user/endpoints/__init__.py index 439f993deb..e9d0e36ee1 100644 --- a/server/polar/user/endpoints/__init__.py +++ b/server/polar/user/endpoints/__init__.py @@ -4,13 +4,11 @@ from polar.customer_portal.endpoints.subscription import router as subscription_router from polar.routing import APIRouter -from .advertisement import router as advertisement_router from .user import router as user_router router = APIRouter(prefix="/users", tags=["users"]) router.include_router(user_router) -router.include_router(advertisement_router) # Include customer portal endpoints for backwards compatibility router.include_router(order_router, deprecated=True, include_in_schema=False) diff --git a/server/polar/user/endpoints/advertisement.py b/server/polar/user/endpoints/advertisement.py deleted file mode 100644 index 5e9bc3802c..0000000000 --- a/server/polar/user/endpoints/advertisement.py +++ /dev/null @@ -1,201 +0,0 @@ -from typing import Annotated - -from fastapi import Depends, Path -from pydantic import UUID4 - -from polar.exceptions import ResourceNotFound -from polar.kit.db.postgres import AsyncSession -from polar.kit.pagination import ListResource, PaginationParamsQuery -from polar.kit.sorting import Sorting, SortingGetter -from polar.models import AdvertisementCampaign -from polar.openapi import APITag -from polar.postgres import get_db_session -from polar.routing import APIRouter - -from .. import auth -from ..schemas.advertisement import ( - UserAdvertisementCampaign, - UserAdvertisementCampaignCreate, - UserAdvertisementCampaignEnable, - UserAdvertisementCampaignUpdate, -) -from ..service.advertisement import UserAdvertisementSortProperty -from ..service.advertisement import user_advertisement as user_advertisement_service - -router = APIRouter(prefix="/advertisements", tags=["advertisements", APITag.documented]) - -AdvertisementCampaignID = Annotated[ - UUID4, Path(description="The advertisement campaign ID.") -] -AdvertisementCampaignNotFound = { - "description": "Advertisement campaign not found.", - "model": ResourceNotFound.schema(), -} - -ListSorting = Annotated[ - list[Sorting[UserAdvertisementSortProperty]], - Depends(SortingGetter(UserAdvertisementSortProperty, ["-created_at"])), -] - - -@router.get( - "/", - summary="List Advertisements", - response_model=ListResource[UserAdvertisementCampaign], -) -async def list( - auth_subject: auth.UserAdvertisementCampaignsRead, - pagination: PaginationParamsQuery, - sorting: ListSorting, - session: AsyncSession = Depends(get_db_session), -) -> ListResource[UserAdvertisementCampaign]: - """List advertisement campaigns.""" - results, count = await user_advertisement_service.list( - session, - auth_subject, - pagination=pagination, - sorting=sorting, - ) - - return ListResource.from_paginated_results( - [UserAdvertisementCampaign.model_validate(result) for result in results], - count, - pagination, - ) - - -@router.get( - "/{id}", - summary="Get Advertisement", - response_model=UserAdvertisementCampaign, - responses={404: AdvertisementCampaignNotFound}, -) -async def get( - id: AdvertisementCampaignID, - auth_subject: auth.UserAdvertisementCampaignsRead, - session: AsyncSession = Depends(get_db_session), -) -> AdvertisementCampaign: - """Get an advertisement campaign by ID.""" - advertisement_campaign = await user_advertisement_service.get_by_id( - session, auth_subject, id - ) - - if advertisement_campaign is None: - raise ResourceNotFound() - - return advertisement_campaign - - -@router.post( - "/", - summary="Create Advertisement", - response_model=UserAdvertisementCampaign, - status_code=201, - responses={201: {"description": "Advertisement campaign created."}}, -) -async def create( - advertisement_campaign_create: UserAdvertisementCampaignCreate, - auth_subject: auth.UserAdvertisementCampaignsWrite, - session: AsyncSession = Depends(get_db_session), -) -> AdvertisementCampaign: - """Create an advertisement campaign.""" - return await user_advertisement_service.create( - session, - auth_subject, - advertisement_campaign_create=advertisement_campaign_create, - ) - - -@router.patch( - "/{id}", - summary="Update Advertisement", - response_model=UserAdvertisementCampaign, - responses={ - 200: {"description": "Advertisement campaign updated."}, - 404: AdvertisementCampaignNotFound, - }, -) -async def update( - id: AdvertisementCampaignID, - advertisement_campaign_update: UserAdvertisementCampaignUpdate, - auth_subject: auth.UserAdvertisementCampaignsWrite, - session: AsyncSession = Depends(get_db_session), -) -> AdvertisementCampaign: - """Update an advertisement campaign.""" - advertisement_campaign = await user_advertisement_service.get_by_id( - session, auth_subject, id - ) - - if advertisement_campaign is None: - raise ResourceNotFound() - - return await user_advertisement_service.update( - session, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_update=advertisement_campaign_update, - ) - - -@router.post( - "/{id}/enable", - summary="Enable Advertisement", - status_code=204, - responses={ - 204: {"description": "Advertisement campaign enabled on benefit."}, - 404: AdvertisementCampaignNotFound, - }, -) -async def enable( - id: AdvertisementCampaignID, - advertisement_campaign_enable: UserAdvertisementCampaignEnable, - auth_subject: auth.UserAdvertisementCampaignsWrite, - session: AsyncSession = Depends(get_db_session), -) -> None: - """Enable an advertisement campaign on a granted benefit.""" - advertisement_campaign = await user_advertisement_service.get_by_id( - session, auth_subject, id - ) - - if advertisement_campaign is None: - raise ResourceNotFound() - - await user_advertisement_service.enable( - session, - auth_subject, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_enable=advertisement_campaign_enable, - ) - - return None - - -@router.delete( - "/{id}", - summary="Delete Advertisement", - responses={ - 204: {"description": "Advertisement campaign deleted."}, - 404: AdvertisementCampaignNotFound, - }, -) -async def delete( - id: AdvertisementCampaignID, - auth_subject: auth.UserAdvertisementCampaignsWrite, - session: AsyncSession = Depends(get_db_session), -) -> None: - """ - Delete an advertisement campaign. - - It'll be automatically disabled on all granted benefits. - """ - advertisement_campaign = await user_advertisement_service.get_by_id( - session, auth_subject, id - ) - - if advertisement_campaign is None: - raise ResourceNotFound() - - await user_advertisement_service.delete( - session, advertisement_campaign=advertisement_campaign - ) - - return None diff --git a/server/polar/user/service/advertisement.py b/server/polar/user/service/advertisement.py deleted file mode 100644 index 409376ec9b..0000000000 --- a/server/polar/user/service/advertisement.py +++ /dev/null @@ -1,186 +0,0 @@ -import uuid -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from sqlalchemy import Select, UnaryExpression, asc, desc, select, update - -from polar.auth.models import AuthSubject -from polar.customer_portal.service.benefit_grant import ( - customer_benefit_grant as customer_benefit_grant_service, -) -from polar.exceptions import PolarError, PolarRequestValidationError -from polar.kit.db.postgres import AsyncSession -from polar.kit.pagination import PaginationParams, paginate -from polar.kit.services import ResourceServiceReader -from polar.kit.sorting import Sorting -from polar.models import AdvertisementCampaign, BenefitGrant, User -from polar.models.benefit import BenefitType - -from ..schemas.advertisement import ( - UserAdvertisementCampaignCreate, - UserAdvertisementCampaignEnable, - UserAdvertisementCampaignUpdate, -) - - -class UserAdvertisementError(PolarError): ... - - -class UserAdvertisementSortProperty(StrEnum): - created_at = "created_at" - views = "views" - clicks = "clicks" - - -class UserAdvertisementService(ResourceServiceReader[AdvertisementCampaign]): - async def list( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - *, - pagination: PaginationParams, - sorting: list[Sorting[UserAdvertisementSortProperty]] = [ - (UserAdvertisementSortProperty.created_at, True) - ], - ) -> tuple[Sequence[AdvertisementCampaign], int]: - statement = self._get_readable_advertisement_statement(auth_subject) - - order_by_clauses: list[UnaryExpression[Any]] = [] - for criterion, is_desc in sorting: - clause_function = desc if is_desc else asc - if criterion == UserAdvertisementSortProperty.created_at: - order_by_clauses.append( - clause_function(AdvertisementCampaign.created_at) - ) - elif criterion == UserAdvertisementSortProperty.views: - order_by_clauses.append(clause_function(AdvertisementCampaign.views)) - elif criterion == UserAdvertisementSortProperty.clicks: - order_by_clauses.append(clause_function(AdvertisementCampaign.clicks)) - statement = statement.order_by(*order_by_clauses) - - return await paginate(session, statement, pagination=pagination) - - async def get_by_id( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - id: uuid.UUID, - ) -> AdvertisementCampaign | None: - statement = self._get_readable_advertisement_statement(auth_subject).where( - AdvertisementCampaign.id == id - ) - - result = await session.execute(statement) - return result.scalar_one_or_none() - - async def create( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - *, - advertisement_campaign_create: UserAdvertisementCampaignCreate, - ) -> AdvertisementCampaign: - advertisement_campaign = AdvertisementCampaign( - user=auth_subject.subject, **advertisement_campaign_create.model_dump() - ) - session.add(advertisement_campaign) - await session.flush() - - return advertisement_campaign - - async def update( - self, - session: AsyncSession, - *, - advertisement_campaign: AdvertisementCampaign, - advertisement_campaign_update: UserAdvertisementCampaignUpdate, - ) -> AdvertisementCampaign: - for attr, value in advertisement_campaign_update.model_dump( - exclude_unset=True - ).items(): - setattr(advertisement_campaign, attr, value) - - session.add(advertisement_campaign) - - return advertisement_campaign - - async def enable( - self, - session: AsyncSession, - auth_subject: AuthSubject[User], - *, - advertisement_campaign: AdvertisementCampaign, - advertisement_campaign_enable: UserAdvertisementCampaignEnable, - ) -> Sequence[BenefitGrant]: - grant = await customer_benefit_grant_service.get_by_id( - session, auth_subject, advertisement_campaign_enable.benefit_id - ) - benefit = grant.benefit - - if benefit is None: - raise PolarRequestValidationError( - [ - { - "type": "value_error", - "msg": "Benefit does not exist or is not granted.", - "loc": ("body", "benefit_id"), - "input": advertisement_campaign_enable.benefit_id, - } - ] - ) - - if benefit.type != BenefitType.ads: - raise PolarRequestValidationError( - [ - { - "type": "value_error", - "msg": "Not an advertisement benefit.", - "loc": ("body", "benefit_id"), - "input": advertisement_campaign_enable.benefit_id, - } - ] - ) - - updated_grants: list[BenefitGrant] = [] - for grant in benefit.grants: - # Those are guaranteed by the query in get_by_id, but let's be explicit - assert grant.user_id == auth_subject.subject.id - assert grant.is_granted - grant.properties = { - "advertisement_campaign_id": str(advertisement_campaign.id) - } - session.add(grant) - updated_grants.append(grant) - - return updated_grants - - async def delete( - self, session: AsyncSession, *, advertisement_campaign: AdvertisementCampaign - ) -> AdvertisementCampaign: - advertisement_campaign.set_deleted_at() - session.add(advertisement_campaign) - - statement = ( - update(BenefitGrant) - .where( - BenefitGrant.properties["advertisement_campaign_id"].astext - == str(advertisement_campaign.id), - ) - .values(properties={}) - ) - await session.execute(statement) - - return advertisement_campaign - - def _get_readable_advertisement_statement( - self, auth_subject: AuthSubject[User] - ) -> Select[tuple[AdvertisementCampaign]]: - statement = select(AdvertisementCampaign).where( - AdvertisementCampaign.deleted_at.is_(None), - AdvertisementCampaign.user_id == auth_subject.subject.id, - ) - return statement - - -user_advertisement = UserAdvertisementService(AdvertisementCampaign) diff --git a/server/tests/advertisements/__init__.py b/server/tests/advertisements/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/server/tests/advertisements/conftest.py b/server/tests/advertisements/conftest.py deleted file mode 100644 index 7fd79f6e18..0000000000 --- a/server/tests/advertisements/conftest.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest_asyncio - -from polar.models import Benefit, Organization -from polar.models.benefit import BenefitType -from tests.fixtures.database import SaveFixture -from tests.fixtures.random_objects import create_benefit - - -@pytest_asyncio.fixture(autouse=True) -async def ads_benefit_organization( - save_fixture: SaveFixture, organization: Organization -) -> Benefit: - return await create_benefit( - save_fixture, - organization=organization, - type=BenefitType.ads, - properties={"image_height": 100, "image_width": 100}, - ) - - -@pytest_asyncio.fixture(autouse=True) -async def ads_benefit_organization_second( - save_fixture: SaveFixture, organization_second: Organization -) -> Benefit: - return await create_benefit( - save_fixture, - organization=organization_second, - type=BenefitType.ads, - properties={"image_height": 100, "image_width": 100}, - ) diff --git a/server/tests/advertisements/test_endpoints.py b/server/tests/advertisements/test_endpoints.py deleted file mode 100644 index 9e41d20a9a..0000000000 --- a/server/tests/advertisements/test_endpoints.py +++ /dev/null @@ -1,100 +0,0 @@ -import uuid - -import pytest -from httpx import AsyncClient - -from polar.models import Organization, User -from polar.models.benefit import BenefitAds -from tests.fixtures.database import SaveFixture -from tests.fixtures.random_objects import ( - create_advertisement_campaign, - create_benefit, - create_benefit_grant, -) - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestListAdvertisementCampaigns: - async def test_not_existing_benefit(self, client: AsyncClient) -> None: - response = await client.get( - "/v1/advertisements/", params={"benefit_id": str(uuid.uuid4())} - ) - - assert response.status_code == 422 - - async def test_not_ads_benefit( - self, client: AsyncClient, save_fixture: SaveFixture, organization: Organization - ) -> None: - benefit = await create_benefit(save_fixture, organization=organization) - response = await client.get( - "/v1/advertisements/", params={"benefit_id": str(benefit.id)} - ) - - assert response.status_code == 422 - - async def test_no_campaign( - self, client: AsyncClient, ads_benefit_organization: BenefitAds - ) -> None: - response = await client.get( - "/v1/advertisements/", - params={"benefit_id": str(ads_benefit_organization.id)}, - ) - - assert response.status_code == 200 - - json = response.json() - assert json["pagination"]["total_count"] == 0 - assert json["items"] == [] - assert json["dimensions"] == [ - ads_benefit_organization.properties["image_width"], - ads_benefit_organization.properties["image_height"], - ] - - async def test_with_campaigns( - self, - client: AsyncClient, - save_fixture: SaveFixture, - user: User, - user_second: User, - ads_benefit_organization: BenefitAds, - ) -> None: - campaign1 = await create_advertisement_campaign(save_fixture, user=user) - await create_benefit_grant( - save_fixture, - user=user, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign1.id)}, - ) - await create_benefit_grant( - save_fixture, - user=user, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign1.id)}, - ) - - campaign2 = await create_advertisement_campaign(save_fixture, user=user_second) - await create_benefit_grant( - save_fixture, - user=user_second, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign2.id)}, - ) - - response = await client.get( - "/v1/advertisements/", - params={"benefit_id": str(ads_benefit_organization.id)}, - ) - - assert response.status_code == 200 - - json = response.json() - assert json["pagination"]["total_count"] == 2 - assert len(json["items"]) == 2 - assert json["dimensions"] == [ - ads_benefit_organization.properties["image_width"], - ads_benefit_organization.properties["image_height"], - ] diff --git a/server/tests/advertisements/test_service.py b/server/tests/advertisements/test_service.py deleted file mode 100644 index 7e2d88f00b..0000000000 --- a/server/tests/advertisements/test_service.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest -import pytest_asyncio - -from polar.advertisement.service import ( - advertisement_campaign as advertisement_campaign_service, -) -from polar.auth.models import AuthSubject -from polar.kit.db.postgres import AsyncSession -from polar.kit.pagination import PaginationParams -from polar.models import Benefit, Organization, User, UserOrganization -from polar.models.benefit import BenefitType -from tests.fixtures.database import SaveFixture -from tests.fixtures.random_objects import ( - create_advertisement_campaign, - create_benefit, - create_benefit_grant, -) - - -@pytest_asyncio.fixture(autouse=True) -async def ads_benefit_organization( - save_fixture: SaveFixture, organization: Organization -) -> Benefit: - return await create_benefit( - save_fixture, organization=organization, type=BenefitType.ads - ) - - -@pytest_asyncio.fixture(autouse=True) -async def ads_benefit_organization_second( - save_fixture: SaveFixture, organization_second: Organization -) -> Benefit: - return await create_benefit( - save_fixture, organization=organization_second, type=BenefitType.ads - ) - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestList: - async def test_valid( - self, - auth_subject: AuthSubject[User | Organization], - save_fixture: SaveFixture, - session: AsyncSession, - user_second: User, - user: User, - user_organization: UserOrganization, - ads_benefit_organization: Benefit, - ads_benefit_organization_second: Benefit, - ) -> None: - campaign1 = await create_advertisement_campaign(save_fixture, user=user) - await create_benefit_grant( - save_fixture, - user=user, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign1.id)}, - ) - await create_benefit_grant( - save_fixture, - user=user, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign1.id)}, - ) - - campaign2 = await create_advertisement_campaign(save_fixture, user=user_second) - await create_benefit_grant( - save_fixture, - user=user_second, - benefit=ads_benefit_organization, - granted=True, - properties={"advertisement_campaign_id": str(campaign2.id)}, - ) - - campaign3 = await create_advertisement_campaign(save_fixture, user=user) - await create_benefit_grant( - save_fixture, - user=user, - benefit=ads_benefit_organization_second, - granted=True, - properties={"advertisement_campaign_id": str(campaign3.id)}, - ) - - advertisement_campaigns, count = await advertisement_campaign_service.list( - session, - benefit_id=ads_benefit_organization.id, - pagination=PaginationParams(1, 10), - ) - - assert count == 2 - assert len(advertisement_campaigns) == 2 - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestTrackView: - async def test_valid( - self, save_fixture: SaveFixture, session: AsyncSession, user: User - ) -> None: - campaign = await create_advertisement_campaign(save_fixture, user=user) - assert campaign.views == 0 - - updated_campaign = await advertisement_campaign_service.track_view( - session, campaign - ) - assert updated_campaign.views == 1 diff --git a/server/tests/user/service/__init__.py b/server/tests/user/service/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/server/tests/user/service/test_advertisement.py b/server/tests/user/service/test_advertisement.py deleted file mode 100644 index d6e29a3b5f..0000000000 --- a/server/tests/user/service/test_advertisement.py +++ /dev/null @@ -1,326 +0,0 @@ -from typing import cast - -import pytest -from pydantic_core import Url - -from polar.auth.models import AuthSubject -from polar.exceptions import PolarRequestValidationError -from polar.kit.db.postgres import AsyncSession -from polar.kit.pagination import PaginationParams -from polar.kit.sorting import Sorting -from polar.models import Organization, User -from polar.models.benefit import BenefitType -from polar.models.benefit_grant import BenefitGrantAdsProperties -from polar.user.schemas.advertisement import ( - UserAdvertisementCampaignCreate, - UserAdvertisementCampaignEnable, - UserAdvertisementCampaignUpdate, -) -from polar.user.service.advertisement import UserAdvertisementSortProperty -from polar.user.service.advertisement import ( - user_advertisement as user_advertisement_service, -) -from tests.fixtures.database import SaveFixture -from tests.fixtures.random_objects import ( - create_advertisement_campaign, - create_benefit, - create_benefit_grant, -) - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestList: - @pytest.mark.auth - async def test_other_user( - self, - auth_subject: AuthSubject[User], - save_fixture: SaveFixture, - session: AsyncSession, - user_second: User, - ) -> None: - await create_advertisement_campaign(save_fixture, user=user_second) - - advertisement_campaigns, count = await user_advertisement_service.list( - session, auth_subject, pagination=PaginationParams(1, 10) - ) - - assert count == 0 - assert len(advertisement_campaigns) == 0 - - @pytest.mark.auth - async def test_user( - self, - auth_subject: AuthSubject[User], - save_fixture: SaveFixture, - session: AsyncSession, - user: User, - ) -> None: - await create_advertisement_campaign(save_fixture, user=user) - - advertisement_campaigns, count = await user_advertisement_service.list( - session, auth_subject, pagination=PaginationParams(1, 10) - ) - - assert count == 1 - assert len(advertisement_campaigns) == 1 - - @pytest.mark.parametrize( - "sorting", - [ - [("created_at", True)], - [("views", True)], - [("clicks", False)], - ], - ) - @pytest.mark.auth - async def test_sorting( - self, - sorting: list[Sorting[UserAdvertisementSortProperty]], - auth_subject: AuthSubject[User], - save_fixture: SaveFixture, - session: AsyncSession, - user: User, - ) -> None: - await create_advertisement_campaign(save_fixture, user=user) - - advertisement_campaigns, count = await user_advertisement_service.list( - session, auth_subject, pagination=PaginationParams(1, 10), sorting=sorting - ) - - assert count == 1 - assert len(advertisement_campaigns) == 1 - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestGetById: - @pytest.mark.auth - async def test_other_user( - self, - auth_subject: AuthSubject[User], - save_fixture: SaveFixture, - session: AsyncSession, - user_second: User, - ) -> None: - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user_second - ) - - result = await user_advertisement_service.get_by_id( - session, auth_subject, advertisement_campaign.id - ) - assert result is None - - @pytest.mark.auth - async def test_user( - self, - auth_subject: AuthSubject[User], - save_fixture: SaveFixture, - session: AsyncSession, - user: User, - ) -> None: - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - - result = await user_advertisement_service.get_by_id( - session, auth_subject, advertisement_campaign.id - ) - - assert result is not None - assert result.id == advertisement_campaign.id - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestCreate: - @pytest.mark.auth - async def test_user( - self, auth_subject: AuthSubject[User], session: AsyncSession, user: User - ) -> None: - advertisement_campaign = await user_advertisement_service.create( - session, - auth_subject, - advertisement_campaign_create=UserAdvertisementCampaignCreate( - image_url=Url("https://loremflickr.com/g/320/240/cat"), - text="Test", - link_url=Url("https://example.com"), - ), - ) - - assert advertisement_campaign is not None - assert advertisement_campaign.user_id == user.id - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestUpdate: - @pytest.mark.auth - async def test_user( - self, - save_fixture: SaveFixture, - session: AsyncSession, - user: User, - ) -> None: - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - - advertisement_campaign = await user_advertisement_service.update( - session, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_update=UserAdvertisementCampaignUpdate( - image_url=Url("https://loremflickr.com/g/320/240/kitten"), - ), - ) - - assert ( - str(advertisement_campaign.image_url) - == "https://loremflickr.com/g/320/240/kitten" - ) - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestEnable: - @pytest.mark.auth - async def test_not_granted_benefit( - self, - save_fixture: SaveFixture, - session: AsyncSession, - auth_subject: AuthSubject[User], - organization: Organization, - user: User, - ) -> None: - benefit = await create_benefit( - save_fixture, organization=organization, type=BenefitType.ads - ) - - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - with pytest.raises(PolarRequestValidationError): - await user_advertisement_service.enable( - session=session, - auth_subject=auth_subject, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_enable=UserAdvertisementCampaignEnable( - benefit_id=benefit.id - ), - ) - - @pytest.mark.auth - async def test_not_ads_benefit( - self, - save_fixture: SaveFixture, - session: AsyncSession, - auth_subject: AuthSubject[User], - organization: Organization, - user: User, - ) -> None: - benefit = await create_benefit( - save_fixture, organization=organization, type=BenefitType.custom - ) - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - with pytest.raises(PolarRequestValidationError): - await user_advertisement_service.enable( - session=session, - auth_subject=auth_subject, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_enable=UserAdvertisementCampaignEnable( - benefit_id=benefit.id - ), - ) - - @pytest.mark.auth - async def test_valid( - self, - save_fixture: SaveFixture, - session: AsyncSession, - auth_subject: AuthSubject[User], - organization: Organization, - user: User, - user_second: User, - ) -> None: - benefit = await create_benefit( - save_fixture, organization=organization, type=BenefitType.ads - ) - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - - user_grant1 = await create_benefit_grant( - save_fixture, user, benefit, granted=True - ) - user_grant2 = await create_benefit_grant( - save_fixture, user, benefit, granted=True - ) - await create_benefit_grant(save_fixture, user_second, benefit, granted=True) - - grants = await user_advertisement_service.enable( - session=session, - auth_subject=auth_subject, - advertisement_campaign=advertisement_campaign, - advertisement_campaign_enable=UserAdvertisementCampaignEnable( - benefit_id=benefit.id - ), - ) - - assert len(grants) == 2 - assert user_grant1 in grants - assert user_grant2 in grants - for grant in grants: - properties = cast(BenefitGrantAdsProperties, grant.properties) - assert properties["advertisement_campaign_id"] == str( - advertisement_campaign.id - ) - - -@pytest.mark.asyncio -@pytest.mark.skip_db_asserts -class TestDelete: - @pytest.mark.auth - async def test_valid( - self, - save_fixture: SaveFixture, - session: AsyncSession, - auth_subject: AuthSubject[User], - organization: Organization, - user: User, - user_second: User, - ) -> None: - benefit = await create_benefit( - save_fixture, organization=organization, type=BenefitType.ads - ) - advertisement_campaign = await create_advertisement_campaign( - save_fixture, user=user - ) - - user_grant1 = await create_benefit_grant( - save_fixture, - user, - benefit, - granted=True, - properties={"advertisement_campaign_id": str(advertisement_campaign.id)}, - ) - user_grant2 = await create_benefit_grant( - save_fixture, - user, - benefit, - granted=True, - properties={"advertisement_campaign_id": str(advertisement_campaign.id)}, - ) - - deleted_advertisement_campaign = await user_advertisement_service.delete( - session, advertisement_campaign=advertisement_campaign - ) - - assert deleted_advertisement_campaign.deleted_at is not None - - for grant in [user_grant1, user_grant2]: - updated_grant = await session.get(grant.__class__, grant.id) - assert updated_grant is not None - assert updated_grant.properties == {} From f84e3f5d940cca419ed3a8fc87661295bd2e2784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 10 Dec 2024 11:31:24 +0100 Subject: [PATCH 20/47] server: store the legacy user id on customer and use it in schemas --- server/polar/customer_portal/schemas/order.py | 9 ++++--- .../customer_portal/schemas/subscription.py | 9 ++++--- server/polar/customer_portal/service/order.py | 2 ++ .../customer_portal/service/subscription.py | 4 ++- server/polar/license_key/schemas.py | 13 +++++++--- server/polar/models/customer.py | 26 +++++++++++++++++++ server/polar/models/user_customer.py | 3 ++- server/polar/order/schemas.py | 17 +++++++----- server/polar/subscription/schemas.py | 17 ++++++++---- 9 files changed, 76 insertions(+), 24 deletions(-) diff --git a/server/polar/customer_portal/schemas/order.py b/server/polar/customer_portal/schemas/order.py index 6ce740d208..74767ecd6f 100644 --- a/server/polar/customer_portal/schemas/order.py +++ b/server/polar/customer_portal/schemas/order.py @@ -1,4 +1,4 @@ -from pydantic import UUID4, Field +from pydantic import UUID4, AliasPath, Field from polar.kit.schemas import Schema, TimestampedSchema from polar.organization.schemas import Organization @@ -19,9 +19,6 @@ class CustomerOrderBase(TimestampedSchema): currency: str customer_id: UUID4 - user_id: UUID4 = Field( - validation_alias="customer_id", deprecated="Use `customer_id`." - ) product_id: UUID4 product_price_id: UUID4 subscription_id: UUID4 | None @@ -38,6 +35,10 @@ class CustomerOrderSubscription(SubscriptionBase): ... class CustomerOrder(CustomerOrderBase): + user_id: UUID4 = Field( + validation_alias=AliasPath("customer", "legacy_user_id"), + deprecated="Use `customer_id`.", + ) product: CustomerOrderProduct product_price: ProductPrice subscription: CustomerOrderSubscription | None diff --git a/server/polar/customer_portal/schemas/subscription.py b/server/polar/customer_portal/schemas/subscription.py index cc9f64a71a..463cfe7737 100644 --- a/server/polar/customer_portal/schemas/subscription.py +++ b/server/polar/customer_portal/schemas/subscription.py @@ -1,6 +1,6 @@ from datetime import datetime -from pydantic import UUID4, Field +from pydantic import UUID4, AliasPath, Field from polar.kit.schemas import Schema from polar.models.subscription import SubscriptionStatus @@ -24,9 +24,6 @@ class CustomerSubscriptionBase(SubscriptionBase): ended_at: datetime | None customer_id: UUID4 - user_id: UUID4 = Field( - validation_alias="customer_id", deprecated="Use `customer_id`." - ) product_id: UUID4 price_id: UUID4 @@ -39,6 +36,10 @@ class CustomerSubscriptionProduct(ProductBase): class CustomerSubscription(CustomerSubscriptionBase): + user_id: UUID4 = Field( + validation_alias=AliasPath("customer", "legacy_user_id"), + deprecated="Use `customer_id`.", + ) product: CustomerSubscriptionProduct price: ProductPrice diff --git a/server/polar/customer_portal/service/order.py b/server/polar/customer_portal/service/order.py index 688cd78194..2bd358f702 100644 --- a/server/polar/customer_portal/service/order.py +++ b/server/polar/customer_portal/service/order.py @@ -64,6 +64,7 @@ async def list( statement = statement.join( Organization, onclause=Product.organization_id == Organization.id ).options( + joinedload(Order.customer), joinedload(Order.subscription), contains_eager(Order.product).options( selectinload(Product.product_medias), @@ -123,6 +124,7 @@ async def get_by_id( self._get_readable_order_statement(auth_subject) .where(Order.id == id) .options( + joinedload(Order.customer), joinedload(Order.product_price), joinedload(Order.subscription), contains_eager(Order.product).options( diff --git a/server/polar/customer_portal/service/subscription.py b/server/polar/customer_portal/service/subscription.py index 531c4373e2..1e20b95575 100644 --- a/server/polar/customer_portal/service/subscription.py +++ b/server/polar/customer_portal/service/subscription.py @@ -85,10 +85,11 @@ async def list( statement.join(Product, onclause=Subscription.product_id == Product.id) .join(Organization, onclause=Product.organization_id == Organization.id) .options( + joinedload(Subscription.customer), contains_eager(Subscription.product).options( selectinload(Product.product_medias), contains_eager(Product.organization), - ) + ), ) ) @@ -148,6 +149,7 @@ async def get_by_id( self._get_readable_subscription_statement(auth_subject) .where(Subscription.id == id) .options( + joinedload(Subscription.customer), joinedload(Subscription.product).options( selectinload(Product.product_medias), joinedload(Product.organization), diff --git a/server/polar/license_key/schemas.py b/server/polar/license_key/schemas.py index 4297263b85..34866fb88d 100644 --- a/server/polar/license_key/schemas.py +++ b/server/polar/license_key/schemas.py @@ -2,7 +2,7 @@ from typing import Any, Literal, Self from dateutil.relativedelta import relativedelta -from pydantic import UUID4, Field +from pydantic import UUID4, AliasPath, Field from polar.benefit.schemas import BenefitID from polar.exceptions import ResourceNotFound, Unauthorized @@ -70,14 +70,21 @@ class LicenseKeyCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): organization_id: UUID4 +class LicenseKeyUser(Schema): + id: UUID4 = Field(validation_alias="legacy_user_id") + email: str + public_name: str = Field(validation_alias="legacy_user_public_name") + + class LicenseKeyRead(Schema): id: UUID4 organization_id: UUID4 user_id: UUID4 = Field( - validation_alias="customer_id", deprecated="Use `customer_id`." + validation_alias=AliasPath("customer", "legacy_user_id"), + deprecated="Use `customer_id`.", ) customer_id: UUID4 - user: LicenseKeyCustomer = Field( + user: LicenseKeyUser = Field( validation_alias="customer", deprecated="Use `customer`." ) customer: LicenseKeyCustomer diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index 10cb812747..116f894be7 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -90,6 +90,22 @@ class Customer(MetadataMixin, RecordModel): "oauth_accounts", JSONB, nullable=False, default=dict ) + _legacy_user_id: Mapped[UUID | None] = mapped_column( + "legacy_user_id", + Uuid, + ForeignKey("users.id", ondelete="set null"), + nullable=True, + ) + """ + Before implementing customers, every customer was a user. This field is used to + keep track of the user that originated this customer. + + It helps us keep backwards compatibility with integrations that used the user ID as + reference to the customer. + + For new customers, this field will be null. + """ + organization_id: Mapped[UUID] = mapped_column( Uuid, ForeignKey("organizations.id", ondelete="cascade"), @@ -122,3 +138,13 @@ def remove_oauth_account( ) -> None: account_key = platform.get_account_key(account_id) self._oauth_accounts.pop(account_key, None) + + @property + def legacy_user_id(self) -> UUID: + return self._legacy_user_id or self.id + + @property + def legacy_user_public_name(self) -> str: + if self.name: + return self.name[0] + return self.email[0] diff --git a/server/polar/models/user_customer.py b/server/polar/models/user_customer.py index d6b96d171a..d25dbe53be 100644 --- a/server/polar/models/user_customer.py +++ b/server/polar/models/user_customer.py @@ -1,6 +1,6 @@ from uuid import UUID -from sqlalchemy import ForeignKey, Uuid +from sqlalchemy import ForeignKey, UniqueConstraint, Uuid from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from polar.kit.db.models.base import RecordModel @@ -11,6 +11,7 @@ class UserCustomer(RecordModel): __tablename__ = "user_customers" + __table_args__ = (UniqueConstraint("user_id", "customer_id"),) user_id: Mapped[UUID] = mapped_column( Uuid, ForeignKey("users.id", ondelete="cascade"), nullable=False diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index e2bebab5ad..c7d710c1ec 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -1,7 +1,7 @@ from typing import Annotated from babel.numbers import format_currency -from pydantic import UUID4, Field +from pydantic import UUID4, AliasPath, Field from polar.custom_field.data import CustomFieldDataOutputMixin from polar.discount.schemas import DiscountMinimal @@ -23,9 +23,6 @@ class OrderBase( billing_reason: OrderBillingReason billing_address: Address | None - user_id: UUID4 = Field( - validation_alias="customer_id", deprecated="Use `customer_id`." - ) customer_id: UUID4 product_id: UUID4 product_price_id: UUID4 @@ -50,6 +47,12 @@ class OrderCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): organization_id: UUID4 +class OrderUser(Schema): + id: UUID4 = Field(validation_alias="legacy_user_id") + email: str + public_name: str = Field(validation_alias="legacy_user_public_name") + + class OrderProduct(ProductBase): ... @@ -61,9 +64,11 @@ class OrderSubscription(SubscriptionBase, MetadataOutputMixin): ... class Order(OrderBase): customer: OrderCustomer - user: OrderCustomer = Field( - validation_alias="customer", deprecated="Use `customer`." + user_id: UUID4 = Field( + validation_alias=AliasPath("customer", "legacy_user_id"), + deprecated="Use `customer_id`.", ) + user: OrderUser = Field(validation_alias="customer", deprecated="Use `customer`.") product: OrderProduct product_price: ProductPrice discount: OrderDiscount | None diff --git a/server/polar/subscription/schemas.py b/server/polar/subscription/schemas.py index c4866fa36b..a10c3954a5 100644 --- a/server/polar/subscription/schemas.py +++ b/server/polar/subscription/schemas.py @@ -2,7 +2,7 @@ from typing import Annotated from babel.numbers import format_currency -from pydantic import UUID4, Field +from pydantic import UUID4, AliasPath, Field from polar.custom_field.data import CustomFieldDataOutputMixin from polar.discount.schemas import DiscountMinimal @@ -30,6 +30,12 @@ class SubscriptionCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): organization_id: UUID4 +class SubscriptionUser(Schema): + id: UUID4 = Field(validation_alias="legacy_user_id") + email: str + public_name: str = Field(validation_alias="legacy_user_public_name") + + class SubscriptionBase(IDSchema, TimestampedSchema): amount: int | None currency: str | None @@ -41,9 +47,6 @@ class SubscriptionBase(IDSchema, TimestampedSchema): started_at: datetime | None ended_at: datetime | None - user_id: UUID4 = Field( - validation_alias="customer_id", deprecated="Use `customer_id`." - ) customer_id: UUID4 product_id: UUID4 price_id: UUID4 @@ -67,7 +70,11 @@ def get_amount_display(self) -> str: class Subscription(CustomFieldDataOutputMixin, MetadataOutputMixin, SubscriptionBase): customer: SubscriptionCustomer - user: SubscriptionCustomer = Field( + user_id: UUID4 = Field( + validation_alias=AliasPath("customer", "legacy_user_id"), + deprecated="Use `customer_id`.", + ) + user: SubscriptionUser = Field( validation_alias="customer", deprecated="Use `customer`." ) product: Product From e4f2996c4fb6fac4837e361b9588b1ee9551a0eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 10 Dec 2024 13:44:37 +0100 Subject: [PATCH 21/47] server/benefit: handle discord refresh token error --- server/polar/benefit/benefits/discord.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/server/polar/benefit/benefits/discord.py b/server/polar/benefit/benefits/discord.py index 1356072690..77320cabf5 100644 --- a/server/polar/benefit/benefits/discord.py +++ b/server/polar/benefit/benefits/discord.py @@ -3,6 +3,7 @@ import httpx import structlog from httpx_oauth.clients.discord import DiscordOAuth2 +from httpx_oauth.oauth2 import RefreshTokenError from polar.auth.models import AuthSubject from polar.config import settings @@ -194,9 +195,20 @@ async def _get_customer_oauth_account( settings.DISCORD_CLIENT_SECRET, scopes=["identify", "email", "guilds.join"], ) - refreshed_token_data = await client.refresh_token( - oauth_account.refresh_token - ) + try: + refreshed_token_data = await client.refresh_token( + oauth_account.refresh_token + ) + except RefreshTokenError as e: + log.warning( + "Failed to refresh Discord access token", + oauth_account_id=oauth_account.account_id, + customer_id=str(customer.id), + error=str(e), + ) + raise BenefitActionRequiredError( + "The customer needs to reconnect their Discord account" + ) from e oauth_account.access_token = refreshed_token_data["access_token"] oauth_account.expires_at = refreshed_token_data["expires_at"] oauth_account.refresh_token = refreshed_token_data["refresh_token"] From 34da9b59d61a679af7be75cad1062bbe59f1b3ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 10 Dec 2024 15:41:03 +0100 Subject: [PATCH 22/47] server: implement customers migration script --- .../2024-12-10-1357_migrate_to_customers.py | 989 ++++++++++++++++++ 1 file changed, 989 insertions(+) create mode 100644 server/migrations/versions/2024-12-10-1357_migrate_to_customers.py diff --git a/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py b/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py new file mode 100644 index 0000000000..8d61c654d6 --- /dev/null +++ b/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py @@ -0,0 +1,989 @@ +"""Migrate to Customers + +Revision ID: e47b6d16d3e0 +Revises: 59538121ff3b +Create Date: 2024-12-09 13:57:40.151264 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# Polar Custom Imports +from polar.kit.address import AddressType +from polar.kit.tax import TaxIDType + +# revision identifiers, used by Alembic. +revision = "e47b6d16d3e0" +down_revision = "59538121ff3b" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # CUSTOMERS + op.create_table( + "customers", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("modified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("organization_id", sa.Uuid(), nullable=False), + sa.Column("email", sa.String(length=320), nullable=False), + sa.Column("email_verified", sa.Boolean(), nullable=False), + sa.Column("stripe_customer_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("billing_address", AddressType(astext_type=sa.Text()), nullable=True), + sa.Column("tax_id", TaxIDType(astext_type=sa.Text()), nullable=True), + sa.Column( + "oauth_accounts", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column("legacy_user_id", sa.Uuid(), nullable=True), + sa.Column( + "user_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + name=op.f("customers_organization_id_fkey"), + ondelete="cascade", + ), + sa.ForeignKeyConstraint( + ["legacy_user_id"], + ["users.id"], + name=op.f("customers_legacy_user_id_fkey"), + ondelete="set null", + ), + sa.PrimaryKeyConstraint("id", name=op.f("customers_pkey")), + sa.UniqueConstraint( + "stripe_customer_id", name=op.f("customers_stripe_customer_id_key") + ), + ) + op.create_index( + op.f("ix_customers_created_at"), "customers", ["created_at"], unique=False + ) + op.create_index( + op.f("ix_customers_deleted_at"), "customers", ["deleted_at"], unique=False + ) + op.create_index( + "ix_customers_email_case_insensitive", + "customers", + [sa.text("lower(email)")], + unique=False, + ) + op.create_index( + op.f("ix_customers_modified_at"), "customers", ["modified_at"], unique=False + ) + op.create_index( + "ix_customers_organization_id_email_case_insensitive", + "customers", + ["organization_id", sa.text("lower(email)")], + unique=True, + ) + + op.create_index( + op.f("ix_customers_tmp_legacy_user_id_organization_id"), + "customers", + ["legacy_user_id", "organization_id"], + unique=True, + ) + + op.execute( + """ + INSERT INTO customers ( + id, + created_at, + email, + email_verified, + stripe_customer_id, + name, + billing_address, + tax_id, + organization_id, + oauth_accounts, + user_metadata, + legacy_user_id + ) + SELECT + uuid_generate_v4(), + users.created_at, + users.email, + users.email_verified, + NULL, + NULL, + NULL, + NULL, + distinct_orders.organization_id, + '{}', + '{}', + users.id + FROM ( + SELECT DISTINCT orders.user_id, products.organization_id + FROM orders + JOIN products ON products.id = orders.product_id + ) AS distinct_orders + JOIN users ON users.id = distinct_orders.user_id; + """ + ) + op.execute( + """ + INSERT INTO customers ( + id, + created_at, + email, + email_verified, + stripe_customer_id, + name, + billing_address, + tax_id, + organization_id, + oauth_accounts, + user_metadata, + legacy_user_id + ) + SELECT + uuid_generate_v4(), + users.created_at, + users.email, + users.email_verified, + NULL, + NULL, + NULL, + NULL, + distinct_subscriptions.organization_id, + '{}', + '{}', + users.id + FROM ( + SELECT DISTINCT subscriptions.user_id, products.organization_id + FROM subscriptions + JOIN products ON products.id = subscriptions.product_id + WHERE (subscriptions.user_id, products.organization_id) NOT IN ( + SELECT legacy_user_id, organization_id + FROM customers + ) + ) AS distinct_subscriptions + JOIN users ON users.id = distinct_subscriptions.user_id; + """ + ) + op.execute( + """ + INSERT INTO customers ( + id, + created_at, + email, + email_verified, + stripe_customer_id, + name, + billing_address, + tax_id, + organization_id, + oauth_accounts, + user_metadata, + legacy_user_id + ) + SELECT + uuid_generate_v4(), + users.created_at, + users.email, + users.email_verified, + NULL, + NULL, + NULL, + NULL, + distinct_benefit_grants.organization_id, + '{}', + '{}', + users.id + FROM ( + SELECT DISTINCT benefit_grants.user_id, benefits.organization_id + FROM benefit_grants + JOIN benefits ON benefits.id = benefit_grants.benefit_id + WHERE (benefit_grants.user_id, benefits.organization_id) NOT IN ( + SELECT legacy_user_id, organization_id + FROM customers + ) + ) AS distinct_benefit_grants + JOIN users ON users.id = distinct_benefit_grants.user_id; + """ + ) + + op.execute( + """ + UPDATE customers c + SET oauth_accounts = + c.oauth_accounts || + ( + SELECT jsonb_object_agg( + oa.platform || ':' || oa.account_id, + jsonb_build_object( + 'access_token', oa.access_token, + 'account_id', oa.account_id, + 'account_username', oa.account_username, + 'expires_at', oa.expires_at, + 'refresh_token', oa.refresh_token, + 'refresh_token_expires_at', oa.refresh_token_expires_at + ) + ) + FROM oauth_accounts oa + WHERE oa.user_id = c.legacy_user_id + AND oa.platform IN ('github', 'discord') + ) + WHERE EXISTS ( + SELECT 1 + FROM oauth_accounts oa + WHERE oa.user_id = c.legacy_user_id + AND oa.platform IN ('github', 'discord') + ); + """ + ) + + op.drop_index( + "ix_customers_tmp_legacy_user_id_organization_id", table_name="customers" + ) + + # CUSTOMER SESSIONS + + op.create_table( + "customer_sessions", + sa.Column("token", sa.CHAR(length=64), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("customer_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("modified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["customer_id"], + ["customers.id"], + name=op.f("customer_sessions_customer_id_fkey"), + ondelete="cascade", + ), + sa.PrimaryKeyConstraint("id", name=op.f("customer_sessions_pkey")), + sa.UniqueConstraint("token", name=op.f("customer_sessions_token_key")), + ) + op.create_index( + op.f("ix_customer_sessions_created_at"), + "customer_sessions", + ["created_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_sessions_deleted_at"), + "customer_sessions", + ["deleted_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_sessions_expires_at"), + "customer_sessions", + ["expires_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_sessions_modified_at"), + "customer_sessions", + ["modified_at"], + unique=False, + ) + + # USER CUSTOMERS + + op.create_table( + "user_customers", + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("customer_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("modified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["customer_id"], + ["customers.id"], + name=op.f("user_customers_customer_id_fkey"), + ondelete="cascade", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("user_customers_user_id_fkey"), + ondelete="cascade", + ), + sa.PrimaryKeyConstraint("id", name=op.f("user_customers_pkey")), + sa.UniqueConstraint( + "user_id", "customer_id", name="user_customers_user_id_customer_id_key" + ), + ) + op.create_index( + op.f("ix_user_customers_created_at"), + "user_customers", + ["created_at"], + unique=False, + ) + op.create_index( + op.f("ix_user_customers_deleted_at"), + "user_customers", + ["deleted_at"], + unique=False, + ) + op.create_index( + op.f("ix_user_customers_modified_at"), + "user_customers", + ["modified_at"], + unique=False, + ) + + op.execute( + """ + INSERT INTO user_customers ( + id, + created_at, + user_id, + customer_id + ) + SELECT + uuid_generate_v4(), + now(), + customers.legacy_user_id, + customers.id + FROM customers + """ + ) + + # BENEFIT GRANTS + + op.add_column("benefit_grants", sa.Column("customer_id", sa.Uuid(), nullable=True)) + + op.execute( + """ + UPDATE benefit_grants + SET customer_id = customers.id + FROM customers, benefits + WHERE benefits.id = benefit_grants.benefit_id + AND customers.legacy_user_id = benefit_grants.user_id + AND customers.organization_id = benefits.organization_id + """ + ) + + op.alter_column("benefit_grants", "customer_id", nullable=False) + + op.drop_constraint("benefit_grants_sbu_key", "benefit_grants", type_="unique") + op.drop_index("ix_benefit_grants_user_id", table_name="benefit_grants") + op.create_unique_constraint( + "benefit_grants_sbc_key", + "benefit_grants", + ["subscription_id", "customer_id", "benefit_id"], + ) + op.create_index( + op.f("ix_benefit_grants_customer_id"), + "benefit_grants", + ["customer_id"], + unique=False, + ) + + op.execute( + "ALTER TABLE benefit_grants DROP CONSTRAINT IF EXISTS benefit_grants_user_id_fkey" + ) + op.execute( + "ALTER TABLE benefit_grants DROP CONSTRAINT IF EXISTS subscription_benefit_grants_user_id_fkey" + ) + op.create_foreign_key( + op.f("benefit_grants_customer_id_fkey"), + "benefit_grants", + "customers", + ["customer_id"], + ["id"], + ondelete="cascade", + ) + op.drop_column("benefit_grants", "user_id") + + # CHECKOUTS + + op.drop_constraint("checkouts_customer_id_fkey", "checkouts", type_="foreignkey") + op.execute( + """ + UPDATE checkouts + SET customer_id = customers.id + FROM customers, products + WHERE customers.legacy_user_id = checkouts.customer_id + AND products.id = checkouts.product_id + AND products.organization_id = customers.organization_id + """ + ) + op.execute( + """ + UPDATE checkouts + SET customer_id = NULL + WHERE customer_id NOT IN (SELECT id FROM customers)""" + ) + op.create_foreign_key( + op.f("checkouts_customer_id_fkey"), + "checkouts", + "customers", + ["customer_id"], + ["id"], + ondelete="set null", + ) + + # DOWNLOADABLES + + op.add_column("downloadables", sa.Column("customer_id", sa.Uuid(), nullable=True)) + op.execute( + """ + UPDATE downloadables + SET customer_id = customers.id + FROM customers, benefits + WHERE customers.legacy_user_id = downloadables.user_id + AND benefits.id = downloadables.benefit_id + AND customers.organization_id = benefits.organization_id + """ + ) + op.alter_column("downloadables", "customer_id", nullable=False) + op.drop_constraint( + "downloadables_user_id_file_id_benefit_id_key", "downloadables", type_="unique" + ) + op.drop_index("ix_downloadables_user_id", table_name="downloadables") + op.create_unique_constraint( + op.f("downloadables_customer_id_file_id_benefit_id_key"), + "downloadables", + ["customer_id", "file_id", "benefit_id"], + ) + op.create_index( + op.f("ix_downloadables_customer_id"), + "downloadables", + ["customer_id"], + unique=False, + ) + op.execute( + "ALTER TABLE downloadables DROP CONSTRAINT IF EXISTS file_permissions_user_id_fkey" + ) + op.execute( + "ALTER TABLE downloadables DROP CONSTRAINT IF EXISTS downloadables_user_id_fkey" + ) + op.create_foreign_key( + op.f("downloadables_customer_id_fkey"), + "downloadables", + "customers", + ["customer_id"], + ["id"], + ondelete="cascade", + ) + op.drop_column("downloadables", "user_id") + + # LICENSE KEYS + + op.add_column("license_keys", sa.Column("customer_id", sa.Uuid(), nullable=True)) + op.execute( + """ + UPDATE license_keys + SET customer_id = customers.id + FROM customers, benefits + WHERE customers.legacy_user_id = license_keys.user_id + AND benefits.id = license_keys.benefit_id + AND customers.organization_id = benefits.organization_id + """ + ) + op.alter_column("license_keys", "customer_id", nullable=False) + op.drop_index("ix_license_keys_user_id", table_name="license_keys") + op.create_index( + op.f("ix_license_keys_customer_id"), + "license_keys", + ["customer_id"], + unique=False, + ) + op.drop_constraint("license_keys_user_id_fkey", "license_keys", type_="foreignkey") + op.create_foreign_key( + op.f("license_keys_customer_id_fkey"), + "license_keys", + "customers", + ["customer_id"], + ["id"], + ondelete="cascade", + ) + op.drop_column("license_keys", "user_id") + + # ORDERS + + op.add_column("orders", sa.Column("customer_id", sa.Uuid(), nullable=True)) + op.execute( + """ + UPDATE orders + SET customer_id = customers.id + FROM customers, products + WHERE customers.legacy_user_id = orders.user_id + AND products.id = orders.product_id + AND products.organization_id = customers.organization_id + """ + ) + op.alter_column("orders", "customer_id", nullable=False) + op.execute("ALTER TABLE orders DROP CONSTRAINT IF EXISTS sales_user_id_fkey") + op.execute("ALTER TABLE orders DROP CONSTRAINT IF EXISTS orders_user_id_fkey") + op.create_foreign_key( + op.f("orders_customer_id_fkey"), "orders", "customers", ["customer_id"], ["id"] + ) + op.drop_column("orders", "user_id") + + # SUBSCRIPTIONS + + op.add_column("subscriptions", sa.Column("customer_id", sa.Uuid(), nullable=True)) + op.execute( + """ + UPDATE subscriptions + SET customer_id = customers.id + FROM customers, products + WHERE customers.legacy_user_id = subscriptions.user_id + AND products.id = subscriptions.product_id + AND products.organization_id = customers.organization_id + """ + ) + op.alter_column("subscriptions", "customer_id", nullable=False) + op.drop_index("ix_subscriptions_user_id", table_name="subscriptions") + op.create_index( + op.f("ix_subscriptions_customer_id"), + "subscriptions", + ["customer_id"], + unique=False, + ) + op.drop_constraint( + "subscriptions_user_id_fkey", "subscriptions", type_="foreignkey" + ) + op.create_foreign_key( + op.f("subscriptions_customer_id_fkey"), + "subscriptions", + "customers", + ["customer_id"], + ["id"], + ondelete="cascade", + ) + op.drop_column("subscriptions", "user_id") + + # TRANSACTIONS + + op.add_column( + "transactions", sa.Column("payment_customer_id", sa.Uuid(), nullable=True) + ) + op.execute( + """ + UPDATE transactions + SET payment_customer_id = customers.id + FROM customers + WHERE customers.legacy_user_id = transactions.payment_user_id + """ + ) + op.execute( + """ + UPDATE transactions + SET payment_user_id = NULL + WHERE payment_customer_id IS NOT NULL + """ + ) + op.create_index( + op.f("ix_transactions_payment_customer_id"), + "transactions", + ["payment_customer_id"], + unique=False, + ) + op.create_foreign_key( + op.f("transactions_payment_customer_id_fkey"), + "transactions", + "customers", + ["payment_customer_id"], + ["id"], + ondelete="set null", + ) + + # BenefitPreconditionErrorNotification + op.execute( + """ + DELETE FROM notifications + WHERE type = 'BenefitPreconditionErrorNotification' + """ + ) + + # License key activations enable_customer_admin flag + op.execute( + """ + UPDATE benefits + SET properties = jsonb_set(properties #- '{activations,enable_user_admin}', '{activations,enable_customer_admin}', properties #> '{activations,enable_user_admin}') + WHERE type = 'license_keys' + AND properties #> '{activations}' ? 'enable_user_admin' + """ + ) + + # Replace removed scopes + op.execute( + """ + WITH splitted_scope AS ( + SELECT id, regexp_split_to_table(scope, '\\s+') as s + FROM personal_access_tokens + ), user_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + ), has_write_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:subscriptions:write', + 'user:advertisement_campaigns:write' + ) + ), aggregated_scope AS ( + SELECT user_scope.id, + string_agg(splitted_scope.s, ' ') AS s, + EXISTS (SELECT 1 FROM has_write_scope WHERE has_write_scope.id = user_scope.id) as has_write + FROM splitted_scope + JOIN user_scope ON splitted_scope.id = user_scope.id + WHERE splitted_scope.s NOT IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + GROUP BY user_scope.id + ) + UPDATE personal_access_tokens + SET scope = CASE + WHEN aggregated_scope.has_write + THEN aggregated_scope.s || ' customer_portal:read customer_portal:write' + ELSE aggregated_scope.s || ' customer_portal:read' + END + FROM aggregated_scope + WHERE aggregated_scope.id = personal_access_tokens.id; + """ + ) + op.execute( + """ + WITH splitted_scope AS ( + SELECT id, regexp_split_to_table(scope, '\\s+') as s + FROM oauth2_tokens + ), user_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + ), has_write_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:subscriptions:write', + 'user:advertisement_campaigns:write' + ) + ), aggregated_scope AS ( + SELECT user_scope.id, + string_agg(splitted_scope.s, ' ') AS s, + EXISTS (SELECT 1 FROM has_write_scope WHERE has_write_scope.id = user_scope.id) as has_write + FROM splitted_scope + JOIN user_scope ON splitted_scope.id = user_scope.id + WHERE splitted_scope.s NOT IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + GROUP BY user_scope.id + ) + UPDATE oauth2_tokens + SET scope = CASE + WHEN aggregated_scope.has_write + THEN aggregated_scope.s || ' customer_portal:read customer_portal:write' + ELSE aggregated_scope.s || ' customer_portal:read' + END + FROM aggregated_scope + WHERE aggregated_scope.id = oauth2_tokens.id; + """ + ) + op.execute( + """ + WITH splitted_scope AS ( + SELECT id, regexp_split_to_table(scope, '\\s+') as s + FROM oauth2_grants + ), user_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + ), has_write_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:subscriptions:write', + 'user:advertisement_campaigns:write' + ) + ), aggregated_scope AS ( + SELECT user_scope.id, + string_agg(splitted_scope.s, ' ') AS s, + EXISTS (SELECT 1 FROM has_write_scope WHERE has_write_scope.id = user_scope.id) as has_write + FROM splitted_scope + JOIN user_scope ON splitted_scope.id = user_scope.id + WHERE splitted_scope.s NOT IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + GROUP BY user_scope.id + ) + UPDATE oauth2_grants + SET scope = CASE + WHEN aggregated_scope.has_write + THEN aggregated_scope.s || ' customer_portal:read customer_portal:write' + ELSE aggregated_scope.s || ' customer_portal:read' + END + FROM aggregated_scope + WHERE aggregated_scope.id = oauth2_grants.id; + """ + ) + op.execute( + """ + WITH splitted_scope AS ( + SELECT id, regexp_split_to_table((client_metadata::JSONB)->>'scope', '\\s+') as s + FROM oauth2_clients + ), user_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + ), has_write_scope AS ( + SELECT DISTINCT id + FROM splitted_scope + WHERE splitted_scope.s IN ( + 'user:subscriptions:write', + 'user:advertisement_campaigns:write' + ) + ), aggregated_scope AS ( + SELECT user_scope.id, + string_agg(splitted_scope.s, ' ') AS s, + EXISTS (SELECT 1 FROM has_write_scope WHERE has_write_scope.id = user_scope.id) as has_write + FROM splitted_scope + JOIN user_scope ON splitted_scope.id = user_scope.id + WHERE splitted_scope.s NOT IN ( + 'user:benefits:read', + 'user:orders:read', + 'user:subscriptions:read', + 'user:subscriptions:write', + 'user:downloadables:read', + 'user:advertisement_campaigns:read', + 'user:advertisement_campaigns:write', + 'user:license_keys:read' + ) + GROUP BY user_scope.id + ) + UPDATE oauth2_clients + SET client_metadata = jsonb_set( + client_metadata::jsonb, + '{scope}', + to_jsonb( + CASE + WHEN aggregated_scope.has_write + THEN aggregated_scope.s || ' customer_portal:read customer_portal:write' + ELSE aggregated_scope.s || ' customer_portal:read' + END + ), + true + )::text + FROM aggregated_scope + WHERE aggregated_scope.id = oauth2_clients.id; + """ + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + op.f("transactions_payment_customer_id_fkey"), + "transactions", + type_="foreignkey", + ) + op.drop_index( + op.f("ix_transactions_payment_customer_id"), table_name="transactions" + ) + op.drop_column("transactions", "payment_customer_id") + op.add_column( + "subscriptions", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False), + ) + op.drop_constraint( + op.f("subscriptions_customer_id_fkey"), "subscriptions", type_="foreignkey" + ) + op.create_foreign_key( + "subscriptions_user_id_fkey", + "subscriptions", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_index(op.f("ix_subscriptions_customer_id"), table_name="subscriptions") + op.create_index( + "ix_subscriptions_user_id", "subscriptions", ["user_id"], unique=False + ) + op.drop_column("subscriptions", "customer_id") + op.add_column( + "orders", sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False) + ) + op.drop_constraint(op.f("orders_customer_id_fkey"), "orders", type_="foreignkey") + op.create_foreign_key("orders_user_id_fkey", "orders", "users", ["user_id"], ["id"]) + op.drop_column("orders", "customer_id") + op.add_column( + "license_keys", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False), + ) + op.drop_constraint( + op.f("license_keys_customer_id_fkey"), "license_keys", type_="foreignkey" + ) + op.create_foreign_key( + "license_keys_user_id_fkey", + "license_keys", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_index(op.f("ix_license_keys_customer_id"), table_name="license_keys") + op.create_index( + "ix_license_keys_user_id", "license_keys", ["user_id"], unique=False + ) + op.drop_column("license_keys", "customer_id") + op.add_column( + "downloadables", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False), + ) + op.drop_constraint( + op.f("downloadables_customer_id_fkey"), "downloadables", type_="foreignkey" + ) + op.create_foreign_key( + "downloadables_user_id_fkey", + "downloadables", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_index(op.f("ix_downloadables_customer_id"), table_name="downloadables") + op.drop_constraint( + op.f("downloadables_customer_id_file_id_benefit_id_key"), + "downloadables", + type_="unique", + ) + op.create_index( + "ix_downloadables_user_id", "downloadables", ["user_id"], unique=False + ) + op.create_unique_constraint( + "downloadables_user_id_file_id_benefit_id_key", + "downloadables", + ["user_id", "file_id", "benefit_id"], + ) + op.drop_column("downloadables", "customer_id") + op.drop_constraint( + op.f("checkouts_customer_id_fkey"), "checkouts", type_="foreignkey" + ) + op.create_foreign_key( + "checkouts_customer_id_fkey", + "checkouts", + "users", + ["customer_id"], + ["id"], + ondelete="CASCADE", + ) + op.add_column( + "benefit_grants", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=False), + ) + op.drop_constraint( + op.f("benefit_grants_customer_id_fkey"), "benefit_grants", type_="foreignkey" + ) + op.create_foreign_key( + "benefit_grants_user_id_fkey", + "benefit_grants", + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_index(op.f("ix_benefit_grants_customer_id"), table_name="benefit_grants") + op.drop_constraint("benefit_grants_sbc_key", "benefit_grants", type_="unique") + op.create_index( + "ix_benefit_grants_user_id", "benefit_grants", ["user_id"], unique=False + ) + op.create_unique_constraint( + "benefit_grants_sbu_key", + "benefit_grants", + ["subscription_id", "user_id", "benefit_id"], + ) + op.drop_column("benefit_grants", "customer_id") + op.drop_index(op.f("ix_user_customers_modified_at"), table_name="user_customers") + op.drop_index(op.f("ix_user_customers_deleted_at"), table_name="user_customers") + op.drop_index(op.f("ix_user_customers_created_at"), table_name="user_customers") + op.drop_table("user_customers") + op.drop_index( + op.f("ix_customer_sessions_modified_at"), table_name="customer_sessions" + ) + op.drop_index( + op.f("ix_customer_sessions_expires_at"), table_name="customer_sessions" + ) + op.drop_index( + op.f("ix_customer_sessions_deleted_at"), table_name="customer_sessions" + ) + op.drop_index( + op.f("ix_customer_sessions_created_at"), table_name="customer_sessions" + ) + op.drop_table("customer_sessions") + op.drop_index( + "ix_customers_organization_id_email_case_insensitive", table_name="customers" + ) + op.drop_index(op.f("ix_customers_modified_at"), table_name="customers") + op.drop_index("ix_customers_email_case_insensitive", table_name="customers") + op.drop_index(op.f("ix_customers_deleted_at"), table_name="customers") + op.drop_index(op.f("ix_customers_created_at"), table_name="customers") + op.drop_table("customers") + # ### end Alembic commands ### From 8d7f375e555a7810cf34cae83c36f0d0c4c701d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Tue, 10 Dec 2024 16:13:10 +0100 Subject: [PATCH 23/47] server: wire customer router --- server/polar/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/polar/api.py b/server/polar/api.py index 38ad86f917..7790ab9d28 100644 --- a/server/polar/api.py +++ b/server/polar/api.py @@ -9,6 +9,7 @@ from polar.checkout.legacy.endpoints import router as checkout_legacy_router from polar.checkout_link.endpoints import router as checkout_link_router from polar.custom_field.endpoints import router as custom_field_router +from polar.customer.endpoints import router as customer_router from polar.customer_portal.endpoints import router as customer_portal_router from polar.dashboard.endpoints import router as dashboard_router from polar.discount.endpoints import router as discount_router @@ -127,5 +128,7 @@ router.include_router(embed_router) # /discounts router.include_router(discount_router) +# /customers +router.include_router(customer_router) # /customer-portal router.include_router(customer_portal_router) From 714a25eb57827787216bcc05c4e1e4a72e9cb0b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 09:01:23 +0100 Subject: [PATCH 24/47] server/customer_portal: add checkout_id filter to benefit grant list endpoint --- .../endpoints/benefit_grant.py | 4 +++ .../customer_portal/service/benefit_grant.py | 25 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/server/polar/customer_portal/endpoints/benefit_grant.py b/server/polar/customer_portal/endpoints/benefit_grant.py index 2cf22e42a2..a3faa07169 100644 --- a/server/polar/customer_portal/endpoints/benefit_grant.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -53,6 +53,9 @@ async def list( organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." ), + checkout_id: MultipleQueryFilter[UUID4] | None = Query( + None, title="CheckoutID Filter", description="Filter by checkout ID." + ), order_id: MultipleQueryFilter[UUID4] | None = Query( None, title="OrderID Filter", description="Filter by order ID." ), @@ -67,6 +70,7 @@ async def list( auth_subject, type=type, organization_id=organization_id, + checkout_id=checkout_id, order_id=order_id, subscription_id=subscription_id, pagination=pagination, diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py index ddd03da138..ffbaf9fde4 100644 --- a/server/polar/customer_portal/service/benefit_grant.py +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -3,7 +3,7 @@ from enum import StrEnum from typing import Any, cast -from sqlalchemy import Select, UnaryExpression, asc, desc, select +from sqlalchemy import Select, UnaryExpression, asc, desc, or_, select from sqlalchemy.orm import contains_eager from polar.auth.models import AuthSubject, is_customer, is_user @@ -17,7 +17,9 @@ Benefit, BenefitGrant, Customer, + Order, Organization, + Subscription, User, UserCustomer, ) @@ -46,6 +48,7 @@ async def list( type: Sequence[BenefitType] | None = None, benefit_id: Sequence[uuid.UUID] | None = None, organization_id: Sequence[uuid.UUID] | None = None, + checkout_id: Sequence[uuid.UUID] | None = None, order_id: Sequence[uuid.UUID] | None = None, subscription_id: Sequence[uuid.UUID] | None = None, pagination: PaginationParams, @@ -64,6 +67,26 @@ async def list( if organization_id is not None: statement = statement.where(Benefit.organization_id.in_(organization_id)) + if checkout_id is not None: + statement = ( + statement.join( + Subscription, + onclause=Subscription.id == BenefitGrant.subscription_id, + isouter=True, + ) + .join( + Order, + onclause=Order.id == BenefitGrant.order_id, + isouter=True, + ) + .where( + or_( + Subscription.checkout_id.in_(checkout_id), + Order.checkout_id.in_(checkout_id), + ) + ) + ) + if order_id is not None: statement = statement.where(BenefitGrant.order_id.in_(order_id)) From 4174b0c82b09e1040ecd5d042c0d34ec9e107318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 09:16:15 +0100 Subject: [PATCH 25/47] server/customer_portal: fix scope for CustomerPortalRead authenticator --- server/polar/customer_portal/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/polar/customer_portal/auth.py b/server/polar/customer_portal/auth.py index ab699a84a8..03ac9710d0 100644 --- a/server/polar/customer_portal/auth.py +++ b/server/polar/customer_portal/auth.py @@ -10,7 +10,7 @@ required_scopes={ Scope.web_default, Scope.customer_portal_read, - Scope.custom_fields_write, + Scope.customer_portal_write, }, allowed_subjects={User, Customer}, ) From 3639cc67dcc0a7c5c13bc8306970efa91a6151d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 09:22:28 +0100 Subject: [PATCH 26/47] server/customer_portal: tweak benefit grant schema --- .../endpoints/benefit_grant.py | 21 +++-- .../customer_portal/schemas/benefit_grant.py | 91 ++++++++++--------- .../customer_portal/service/benefit_grant.py | 14 +-- 3 files changed, 69 insertions(+), 57 deletions(-) diff --git a/server/polar/customer_portal/endpoints/benefit_grant.py b/server/polar/customer_portal/endpoints/benefit_grant.py index a3faa07169..906c4433ed 100644 --- a/server/polar/customer_portal/endpoints/benefit_grant.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -16,8 +16,11 @@ from polar.routing import APIRouter from .. import auth -from ..schemas.benefit_grant import BenefitGrant as BenefitGrantSchema -from ..schemas.benefit_grant import BenefitGrantAdapter, BenefitGrantUpdate +from ..schemas.benefit_grant import ( + CustomerBenefitGrant, + CustomerBenefitGrantAdapter, + CustomerBenefitGrantUpdate, +) from ..service.benefit_grant import CustomerBenefitGrantSortProperty from ..service.benefit_grant import ( customer_benefit_grant as customer_benefit_grant_service, @@ -41,7 +44,9 @@ @router.get( - "/", summary="List Benefit Grants", response_model=ListResource[BenefitGrantSchema] + "/", + summary="List Benefit Grants", + response_model=ListResource[CustomerBenefitGrant], ) async def list( auth_subject: auth.CustomerPortalRead, @@ -63,7 +68,7 @@ async def list( None, title="SubscriptionID Filter", description="Filter by subscription ID." ), session: AsyncSession = Depends(get_db_session), -) -> ListResource[BenefitGrantSchema]: +) -> ListResource[CustomerBenefitGrant]: """List benefits grants of the authenticated customer or user.""" results, count = await customer_benefit_grant_service.list( session, @@ -78,7 +83,7 @@ async def list( ) return ListResource.from_paginated_results( - [BenefitGrantAdapter.validate_python(result) for result in results], + [CustomerBenefitGrantAdapter.validate_python(result) for result in results], count, pagination, ) @@ -87,7 +92,7 @@ async def list( @router.get( "/{id}", summary="Get Benefit Grant", - response_model=BenefitGrantSchema, + response_model=CustomerBenefitGrant, responses={404: BenefitGrantNotFound}, ) async def get( @@ -109,7 +114,7 @@ async def get( @router.get( "/{id}", summary="Update Benefit Grant", - response_model=BenefitGrantSchema, + response_model=CustomerBenefitGrant, responses={ 200: {"description": "Benefit grant updated."}, 403: { @@ -121,7 +126,7 @@ async def get( ) async def update( id: BenefitGrantID, - benefit_grant_update: BenefitGrantUpdate, + benefit_grant_update: CustomerBenefitGrantUpdate, auth_subject: auth.CustomerPortalWrite, session: AsyncSession = Depends(get_db_session), ) -> BenefitGrant: diff --git a/server/polar/customer_portal/schemas/benefit_grant.py b/server/polar/customer_portal/schemas/benefit_grant.py index 48771f7b8f..bb49c167c6 100644 --- a/server/polar/customer_portal/schemas/benefit_grant.py +++ b/server/polar/customer_portal/schemas/benefit_grant.py @@ -10,9 +10,15 @@ BenefitDownloadablesSubscriber, BenefitGitHubRepositorySubscriber, BenefitLicenseKeysSubscriber, - BenefitSubscriber, ) -from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema +from polar.kit.schemas import ( + ClassName, + IDSchema, + MergeJSONSchema, + Schema, + SetSchemaReference, + TimestampedSchema, +) from polar.models.benefit import BenefitType from polar.models.benefit_grant import ( BenefitGrantAdsProperties, @@ -25,7 +31,7 @@ from polar.models.customer import CustomerOAuthPlatform -class BenefitGrantBase(IDSchema, TimestampedSchema): +class CustomerBenefitGrantBase(IDSchema, TimestampedSchema): granted_at: datetime | None revoked_at: datetime | None customer_id: UUID4 @@ -36,105 +42,104 @@ class BenefitGrantBase(IDSchema, TimestampedSchema): is_revoked: bool -BenefitCustomer = Annotated[ - BenefitSubscriber, - MergeJSONSchema({"title": "BenefitCustomer"}), -] - - -class BenefitGrantDiscord(BenefitGrantBase): +class CustomerBenefitGrantDiscord(CustomerBenefitGrantBase): benefit: BenefitDiscordSubscriber properties: BenefitGrantDiscordProperties -class BenefitGrantGitHubRepository(BenefitGrantBase): +class CustomerBenefitGrantGitHubRepository(CustomerBenefitGrantBase): benefit: BenefitGitHubRepositorySubscriber properties: BenefitGrantGitHubRepositoryProperties -class BenefitGrantDownloadables(BenefitGrantBase): +class CustomerBenefitGrantDownloadables(CustomerBenefitGrantBase): benefit: BenefitDownloadablesSubscriber properties: BenefitGrantDownloadablesProperties -class BenefitGrantLicenseKeys(BenefitGrantBase): +class CustomerBenefitGrantLicenseKeys(CustomerBenefitGrantBase): benefit: BenefitLicenseKeysSubscriber properties: BenefitGrantLicenseKeysProperties -class BenefitGrantAds(BenefitGrantBase): +class CustomerBenefitGrantAds(CustomerBenefitGrantBase): benefit: BenefitAdsSubscriber properties: BenefitGrantAdsProperties -class BenefitGrantCustom(BenefitGrantBase): +class CustomerBenefitGrantCustom(CustomerBenefitGrantBase): benefit: BenefitCustomSubscriber properties: BenefitGrantCustomProperties -BenefitGrant = Annotated[ - BenefitGrantDiscord - | BenefitGrantGitHubRepository - | BenefitGrantDownloadables - | BenefitGrantLicenseKeys - | BenefitGrantAds - | BenefitGrantCustom, - MergeJSONSchema({"title": "BenefitGrant"}), +CustomerBenefitGrant = Annotated[ + CustomerBenefitGrantDiscord + | CustomerBenefitGrantGitHubRepository + | CustomerBenefitGrantDownloadables + | CustomerBenefitGrantLicenseKeys + | CustomerBenefitGrantAds + | CustomerBenefitGrantCustom, + SetSchemaReference("CustomerBenefitGrant"), + MergeJSONSchema({"title": "CustomerBenefitGrant"}), + ClassName("CustomerBenefitGrant"), ] -BenefitGrantAdapter: TypeAdapter[BenefitGrant] = TypeAdapter(BenefitGrant) +CustomerBenefitGrantAdapter: TypeAdapter[CustomerBenefitGrant] = TypeAdapter( + CustomerBenefitGrant +) -class BenefitGrantUpdateBase(Schema): +class CustomerBenefitGrantUpdateBase(Schema): benefit_type: BenefitType -class BenefitGrantDiscordPropertiesUpdate(TypedDict): +class CustomerBenefitGrantDiscordPropertiesUpdate(TypedDict): account_id: str -class BenefitGrantDiscordUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantDiscordUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.discord] - properties: BenefitGrantDiscordPropertiesUpdate + properties: CustomerBenefitGrantDiscordPropertiesUpdate def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.discord]: return CustomerOAuthPlatform.discord -class BenefitGrantGitHubRepositoryPropertiesUpdate(TypedDict): +class CustomerBenefitGrantGitHubRepositoryPropertiesUpdate(TypedDict): account_id: str -class BenefitGrantGitHubRepositoryUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantGitHubRepositoryUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.github_repository] - properties: BenefitGrantGitHubRepositoryPropertiesUpdate + properties: CustomerBenefitGrantGitHubRepositoryPropertiesUpdate def get_oauth_platform(self) -> Literal[CustomerOAuthPlatform.github]: return CustomerOAuthPlatform.github -class BenefitGrantDownloadablesUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantDownloadablesUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.downloadables] -class BenefitGrantLicenseKeysUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantLicenseKeysUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.license_keys] -class BenefitGrantAdsUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantAdsUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.ads] -class BenefitGrantCustomUpdate(BenefitGrantUpdateBase): +class CustomerBenefitGrantCustomUpdate(CustomerBenefitGrantUpdateBase): benefit_type: Literal[BenefitType.custom] -BenefitGrantUpdate = Annotated[ - BenefitGrantDiscordUpdate - | BenefitGrantGitHubRepositoryUpdate - | BenefitGrantDownloadablesUpdate - | BenefitGrantLicenseKeysUpdate - | BenefitGrantAdsUpdate - | BenefitGrantCustomUpdate, - MergeJSONSchema({"title": "BenefitGrantUpdate"}), +CustomerBenefitGrantUpdate = Annotated[ + CustomerBenefitGrantDiscordUpdate + | CustomerBenefitGrantGitHubRepositoryUpdate + | CustomerBenefitGrantDownloadablesUpdate + | CustomerBenefitGrantLicenseKeysUpdate + | CustomerBenefitGrantAdsUpdate + | CustomerBenefitGrantCustomUpdate, + SetSchemaReference("CustomerBenefitGrantUpdate"), + MergeJSONSchema({"title": "CustomerBenefitGrantUpdate"}), Discriminator("benefit_type"), ] diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py index ffbaf9fde4..af34e2ff81 100644 --- a/server/polar/customer_portal/service/benefit_grant.py +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -27,9 +27,9 @@ from polar.worker import enqueue_job from ..schemas.benefit_grant import ( - BenefitGrantDiscordUpdate, - BenefitGrantGitHubRepositoryUpdate, - BenefitGrantUpdate, + CustomerBenefitGrantDiscordUpdate, + CustomerBenefitGrantGitHubRepositoryUpdate, + CustomerBenefitGrantUpdate, ) @@ -125,7 +125,7 @@ async def update( self, session: AsyncSession, benefit_grant: BenefitGrant, - benefit_grant_update: BenefitGrantUpdate, + benefit_grant_update: CustomerBenefitGrantUpdate, ) -> BenefitGrant: if benefit_grant.is_revoked: raise NotPermitted("Cannot update a revoked benefit grant.") @@ -142,8 +142,10 @@ async def update( ] ) - if isinstance(benefit_grant_update, BenefitGrantDiscordUpdate) or isinstance( - benefit_grant_update, BenefitGrantGitHubRepositoryUpdate + if isinstance( + benefit_grant_update, CustomerBenefitGrantDiscordUpdate + ) or isinstance( + benefit_grant_update, CustomerBenefitGrantGitHubRepositoryUpdate ): account_id = benefit_grant_update.properties["account_id"] platform = benefit_grant_update.get_oauth_platform() From 11e58415ce1efffeae84c48f5df9336f322389e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 14:14:55 +0100 Subject: [PATCH 27/47] server/benefit: remove legacy configuration for GitHub repository --- .../benefit/benefits/github_repository.py | 65 ------------------- server/polar/benefit/schemas.py | 14 ++-- server/polar/models/benefit.py | 5 -- 3 files changed, 4 insertions(+), 80 deletions(-) diff --git a/server/polar/benefit/benefits/github_repository.py b/server/polar/benefit/benefits/github_repository.py index fd3fad07e5..4e72c461d5 100644 --- a/server/polar/benefit/benefits/github_repository.py +++ b/server/polar/benefit/benefits/github_repository.py @@ -5,7 +5,6 @@ from githubkit.exception import RateLimitExceeded, RequestError, RequestTimeout from polar.auth.models import AuthSubject, is_organization, is_user -from polar.authz.service import AccessType, Authz from polar.integrations.github import client as github from polar.integrations.github import types from polar.integrations.github_repository_benefit.service import ( @@ -20,7 +19,6 @@ from polar.models.benefit_grant import BenefitGrantGitHubRepositoryProperties from polar.models.customer import CustomerOAuthPlatform from polar.posthog import posthog -from polar.repository.service import repository as repository_service from .base import ( BenefitActionRequiredError, @@ -130,10 +128,6 @@ async def revoke( customer_id=str(customer.id), ) - if benefit.properties["repository_id"]: - bound_logger.info("skipping revoke for old version of this benefit type") - return {} - client = await self._get_github_app_client(benefit) repository_owner = benefit.properties["repository_owner"] @@ -302,65 +296,6 @@ async def validate_properties( }, ) - async def _validate_properties_repository_id( - self, user: User, properties: dict[str, Any] - ) -> BenefitGitHubRepositoryProperties: - repository_id = properties["repository_id"] - - repository = await repository_service.get( - self.session, repository_id, load_organization=True - ) - - if repository is None: - raise BenefitPropertiesValidationError( - [ - { - "type": "invalid_repository", - "msg": "This repository does not exist.", - "loc": ("repository_id",), - "input": repository_id, - } - ] - ) - - authz = Authz(self.session) - if not await authz.can(user, AccessType.write, repository): - raise BenefitPropertiesValidationError( - [ - { - "type": "no_repository_acccess", - "msg": "You don't have access to this repository.", - "loc": ("repository_id",), - "input": repository_id, - } - ] - ) - - if posthog.client and not posthog.client.feature_enabled( - "github-benefit-personal-org", user.posthog_distinct_id - ): - if repository.organization.is_personal: - raise BenefitPropertiesValidationError( - [ - { - "type": "personal_organization_repository", - "msg": "For security reasons, " - "repositories on personal organizations are not supported.", - "loc": ("repository_id",), - "input": repository_id, - } - ] - ) - - return cast( - BenefitGitHubRepositoryProperties, - { - **properties, - "repository_owner": repository.organization.name, - "repository_name": repository.name, - }, - ) - async def _get_invitation( self, client: github.GitHub[Any], diff --git a/server/polar/benefit/schemas.py b/server/polar/benefit/schemas.py index 0e57844de5..afe3a1990c 100644 --- a/server/polar/benefit/schemas.py +++ b/server/polar/benefit/schemas.py @@ -170,15 +170,11 @@ class BenefitGitHubRepositoryCreateProperties(Schema): Properties to create a benefit of type `github_repository`. """ - # For benefits created before 2014-13-15 repository_id will be set - # no new benefits of this type are allowed to be created - repository_id: UUID4 | None = None - # For benefits created after 2014-13-15 both repository_owner and repository_name will be set - repository_owner: str | None = Field( - None, description="The owner of the repository.", examples=["polarsource"] + repository_owner: str = Field( + description="The owner of the repository.", examples=["polarsource"] ) - repository_name: str | None = Field( - None, description="The name of the repository.", examples=["private_repo"] + repository_name: str = Field( + description="The name of the repository.", examples=["private_repo"] ) permission: Permission @@ -188,8 +184,6 @@ class BenefitGitHubRepositoryProperties(Schema): Properties for a benefit of type `github_repository`. """ - # Is set to None for all benefits created after 2024-03-15 - repository_id: UUID4 | None repository_owner: RepositoryOwner repository_name: RepositoryName permission: Permission diff --git a/server/polar/models/benefit.py b/server/polar/models/benefit.py index 6de7ca8e5c..b862091a01 100644 --- a/server/polar/models/benefit.py +++ b/server/polar/models/benefit.py @@ -61,11 +61,6 @@ class BenefitAdsProperties(BenefitProperties): class BenefitGitHubRepositoryProperties(BenefitProperties): - # repository_id was set previously (before 2024-13-15), for benefits using the "main" - # Polar GitHub App for granting benefits. Benefits created after this date are using - # the "Polar Repository Benefit" GitHub App, and only uses the repository_owner - # and repository_name fields. - repository_id: UUID | None repository_owner: str repository_name: str permission: Literal["pull", "triage", "push", "maintain", "admin"] From 3ae7e5491480a1b1e2a421fde104606cd1c7a1f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 11:01:46 +0100 Subject: [PATCH 28/47] server/customer_portal: don't require auth when downloading file We already have a token at hand which should be enough to prove access --- .../endpoints/downloadables.py | 3 +- .../customer_portal/service/downloadables.py | 13 +++--- server/polar/license_key/service.py | 14 +++++-- .../endpoints/test_downloadables.py | 42 ------------------- 4 files changed, 16 insertions(+), 56 deletions(-) diff --git a/server/polar/customer_portal/endpoints/downloadables.py b/server/polar/customer_portal/endpoints/downloadables.py index 032ea70f4d..68cc552b8e 100644 --- a/server/polar/customer_portal/endpoints/downloadables.py +++ b/server/polar/customer_portal/endpoints/downloadables.py @@ -60,11 +60,10 @@ async def list( ) async def get( token: str, - auth_subject: auth.CustomerPortalRead, session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: downloadable = await downloadable_service.get_from_token_or_raise( - session, auth_subject, token=token + session, token=token ) signed = downloadable_service.generate_download_schema(downloadable) return RedirectResponse(signed.file.download.url, 302) diff --git a/server/polar/customer_portal/service/downloadables.py b/server/polar/customer_portal/service/downloadables.py index ca212b5aad..b573b912ea 100644 --- a/server/polar/customer_portal/service/downloadables.py +++ b/server/polar/customer_portal/service/downloadables.py @@ -4,7 +4,7 @@ import structlog from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer -from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import contains_eager, joinedload from polar.auth.models import AuthSubject, is_customer, is_user from polar.config import settings @@ -190,10 +190,7 @@ def create_download_token(self, downloadable: Downloadable) -> DownloadableURL: return DownloadableURL(url=redirect_to, expires_at=expires_at) async def get_from_token_or_raise( - self, - session: AsyncSession, - auth_subject: AuthSubject[User | Customer], - token: str, + self, session: AsyncSession, token: str ) -> Downloadable: try: unpacked = token_serializer.loads( @@ -207,9 +204,9 @@ async def get_from_token_or_raise( except KeyError: raise BadRequest() - statement = self._get_base_query(auth_subject).where(Downloadable.id == id) - res = await session.execute(statement) - downloadable = res.scalars().one_or_none() + downloadable = await self.get( + session, id, options=(joinedload(Downloadable.file),) + ) if not downloadable: raise ResourceNotFound() diff --git a/server/polar/license_key/service.py b/server/polar/license_key/service.py index 4c81578c3c..bb4e8deb49 100644 --- a/server/polar/license_key/service.py +++ b/server/polar/license_key/service.py @@ -453,8 +453,12 @@ async def get_customer_list( benefit_id: UUID | None = None, organization_ids: Sequence[UUID] | None = None, ) -> tuple[Sequence[LicenseKey], int]: - query = self._get_select_customer_base(auth_subject).order_by( - LicenseKey.created_at.asc() + query = ( + self._get_select_customer_base(auth_subject) + .order_by(LicenseKey.created_at.asc()) + .options( + joinedload(LicenseKey.benefit), + ) ) if organization_ids: @@ -471,8 +475,10 @@ async def get_customer_license_key( auth_subject: AuthSubject[User | Customer], license_key_id: UUID, ) -> LicenseKey | None: - query = self._get_select_customer_base(auth_subject).where( - LicenseKey.id == license_key_id + query = ( + self._get_select_customer_base(auth_subject) + .where(LicenseKey.id == license_key_id) + .options(joinedload(LicenseKey.activations), joinedload(LicenseKey.benefit)) ) result = await session.execute(query) return result.unique().scalar_one_or_none() diff --git a/server/tests/customer_portal/endpoints/test_downloadables.py b/server/tests/customer_portal/endpoints/test_downloadables.py index 58d8aa8616..a69747fc35 100644 --- a/server/tests/customer_portal/endpoints/test_downloadables.py +++ b/server/tests/customer_portal/endpoints/test_downloadables.py @@ -26,48 +26,6 @@ async def test_anonymous_list_401s(self, client: AsyncClient) -> None: async def test_anonymous_download_401s(self, client: AsyncClient) -> None: response = await client.get("/v1/customer-portal/downloadables/i-am-hacker") - assert response.status_code == 401 - - @pytest.mark.auth(AuthSubjectFixture(subject="customer")) - async def test_revoked_404s( - self, - session: AsyncSession, - redis: Redis, - client: AsyncClient, - save_fixture: SaveFixture, - customer: Customer, - organization: Organization, - product: Product, - uploaded_logo_jpg: File, - ) -> None: - benefit, granted = await TestDownloadable.create_benefit_and_grant( - session, - redis, - save_fixture, - customer=customer, - organization=organization, - product=product, - properties=BenefitDownloadablesCreateProperties( - files=[uploaded_logo_jpg.id] - ), - ) - - # List of downloadables - response = await client.get("/v1/customer-portal/downloadables/") - assert response.status_code == 200 - data = response.json() - downloadable_list = data["items"] - pagination = data["pagination"] - assert pagination["total_count"] == 1 - assert len(downloadable_list) == 1 - downloadable = downloadable_list[0] - polar_download_url = downloadable["file"]["download"]["url"] - - # Revoke the benefit - await TestDownloadable.run_revoke_task(session, redis, benefit, customer) - - # Polar download endpoint will now 404 - response = await client.get(polar_download_url, follow_redirects=False) assert response.status_code == 404 @pytest.mark.auth(AuthSubjectFixture(subject="customer")) From f63db51c525aa70dc8d460be521abd7d75290d07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 11:47:11 +0100 Subject: [PATCH 29/47] server/customer_portal: add benefit_id filter on grants list --- server/polar/customer_portal/endpoints/benefit_grant.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/polar/customer_portal/endpoints/benefit_grant.py b/server/polar/customer_portal/endpoints/benefit_grant.py index 906c4433ed..60a512395d 100644 --- a/server/polar/customer_portal/endpoints/benefit_grant.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -55,6 +55,9 @@ async def list( type: MultipleQueryFilter[BenefitType] | None = Query( None, title="BenefitType Filter", description="Filter by benefit type." ), + benefit_id: MultipleQueryFilter[UUID4] | None = Query( + None, title="BenefitID Filter", description="Filter by benefit ID." + ), organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." ), @@ -74,6 +77,7 @@ async def list( session, auth_subject, type=type, + benefit_id=benefit_id, organization_id=organization_id, checkout_id=checkout_id, order_id=order_id, From def8b25754ed0da0ea291cbe596009c2c397ff4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 14:22:44 +0100 Subject: [PATCH 30/47] server/customer_portal: filter out revoked benefit grants --- .../customer_portal/service/benefit_grant.py | 1 + .../service/test_benefit_grant.py | 45 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/server/polar/customer_portal/service/benefit_grant.py b/server/polar/customer_portal/service/benefit_grant.py index af34e2ff81..ac3a8226cc 100644 --- a/server/polar/customer_portal/service/benefit_grant.py +++ b/server/polar/customer_portal/service/benefit_grant.py @@ -193,6 +193,7 @@ def _get_readable_benefit_grant_statement( .join(Organization, onclause=Benefit.organization_id == Organization.id) .where( BenefitGrant.deleted_at.is_(None), + BenefitGrant.is_revoked.is_(False), ) .options( contains_eager(BenefitGrant.benefit).options( diff --git a/server/tests/customer_portal/service/test_benefit_grant.py b/server/tests/customer_portal/service/test_benefit_grant.py index 38561e6bba..8461875560 100644 --- a/server/tests/customer_portal/service/test_benefit_grant.py +++ b/server/tests/customer_portal/service/test_benefit_grant.py @@ -51,6 +51,7 @@ async def test_customer( subscription: Subscription, benefit_organization: Benefit, benefit_organization_second: Benefit, + benefit_organization_third: Benefit, customer: Customer, ) -> None: await create_benefit_grant( @@ -65,6 +66,14 @@ async def test_customer( save_fixture, customer, benefit_organization_second, + granted=None, + subscription=subscription, + ) + + await create_benefit_grant( + save_fixture, + customer, + benefit_organization_third, granted=False, subscription=subscription, ) @@ -158,8 +167,42 @@ async def test_customer_revoked( session, auth_subject, grant.id ) + assert result is None + + @pytest.mark.auth(AuthSubjectFixture(subject="customer")) + async def test_customer_pending( + self, + auth_subject: AuthSubject[Customer], + save_fixture: SaveFixture, + session: AsyncSession, + subscription: Subscription, + benefit_organization: Benefit, + customer: Customer, + customer_second: Customer, + ) -> None: + customer_grant = await create_benefit_grant( + save_fixture, + customer, + benefit_organization, + granted=None, + subscription=subscription, + ) + await create_benefit_grant( + save_fixture, + customer_second, + benefit_organization, + granted=True, + subscription=subscription, + ) + + session.expunge_all() + + result = await customer_benefit_grant_service.get_by_id( + session, auth_subject, customer_grant.id + ) + assert result is not None - assert result.is_revoked + assert result.id == customer_grant.id @pytest.mark.auth(AuthSubjectFixture(subject="customer")) async def test_customer_granted( From 1616efc3667d076409e44103b1f6a54e45fb822b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 14:29:32 +0100 Subject: [PATCH 31/47] server/storefront: tweak customer output --- server/polar/storefront/endpoints.py | 9 ++++++++- server/polar/storefront/schemas.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/server/polar/storefront/endpoints.py b/server/polar/storefront/endpoints.py index f9a0fa4247..aa627e778d 100644 --- a/server/polar/storefront/endpoints.py +++ b/server/polar/storefront/endpoints.py @@ -47,7 +47,14 @@ async def get(slug: str, session: AsyncSession = Depends(get_db_session)) -> Sto "donation_product": donation_product, "customers": { "total": total, - "customers": customers, + "customers": [ + { + "name": customer.name[0] + if customer.name + else customer.email[0], + } + for customer in customers + ], }, } ) diff --git a/server/polar/storefront/schemas.py b/server/polar/storefront/schemas.py index fed6cac743..3de30284ee 100644 --- a/server/polar/storefront/schemas.py +++ b/server/polar/storefront/schemas.py @@ -21,7 +21,8 @@ class ProductStorefront(ProductBase): ) -class Customer(Schema): ... +class Customer(Schema): + name: str class Customers(Schema): From fb4d094fd086c5350a4f44fc2f302a021d2692a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 12 Dec 2024 10:54:05 +0100 Subject: [PATCH 32/47] server/customer_portal: tweak OAuth accounts endpoints so it works better with User/Customer sessions --- server/polar/customer/service.py | 17 +++- server/polar/customer_portal/auth.py | 6 +- .../customer_portal/endpoints/__init__.py | 2 + .../endpoints/benefit_grant.py | 2 +- .../customer_portal/endpoints/customer.py | 49 +++++++++++ .../endpoints/oauth_accounts.py | 84 ++++++++++++++----- .../polar/customer_portal/schemas/customer.py | 17 ++++ .../customer_portal/schemas/oauth_accounts.py | 5 ++ server/polar/models/customer.py | 17 +++- 9 files changed, 170 insertions(+), 29 deletions(-) create mode 100644 server/polar/customer_portal/endpoints/customer.py create mode 100644 server/polar/customer_portal/schemas/customer.py create mode 100644 server/polar/customer_portal/schemas/oauth_accounts.py diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 7bfb04b9a8..5a1f0a6156 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -11,7 +11,7 @@ from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader from polar.kit.sorting import Sorting -from polar.models import Customer, Organization, User, UserOrganization +from polar.models import Customer, Organization, User, UserCustomer, UserOrganization from polar.organization.resolver import get_payload_organization from polar.postgres import AsyncSession @@ -162,6 +162,21 @@ async def get_by_email_and_organization( result = await session.execute(statement) return result.scalar_one_or_none() + async def get_by_id_and_user( + self, session: AsyncSession, id: uuid.UUID, user: User + ) -> Customer | None: + statement = ( + select(Customer) + .join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id) + .where( + Customer.deleted_at.is_(None), + Customer.id == id, + UserCustomer.user_id == user.id, + ) + ) + result = await session.execute(statement) + return result.scalar_one_or_none() + async def get_by_stripe_customer_id( self, session: AsyncSession, stripe_customer_id: str ) -> Customer | None: diff --git a/server/polar/customer_portal/auth.py b/server/polar/customer_portal/auth.py index 03ac9710d0..cebe406d63 100644 --- a/server/polar/customer_portal/auth.py +++ b/server/polar/customer_portal/auth.py @@ -3,7 +3,7 @@ from fastapi import Depends from polar.auth.dependencies import Authenticator -from polar.auth.models import AuthSubject, Customer, User +from polar.auth.models import Anonymous, AuthSubject, Customer, User from polar.auth.scope import Scope _CustomerPortalRead = Authenticator( @@ -28,8 +28,8 @@ _CustomerPortalOAuthAccount = Authenticator( required_scopes={Scope.web_default, Scope.customer_portal_write}, - allowed_subjects={Customer}, + allowed_subjects={User, Anonymous}, ) CustomerPortalOAuthAccount = Annotated[ - AuthSubject[Customer], Depends(_CustomerPortalOAuthAccount) + AuthSubject[User | Anonymous], Depends(_CustomerPortalOAuthAccount) ] diff --git a/server/polar/customer_portal/endpoints/__init__.py b/server/polar/customer_portal/endpoints/__init__.py index 5bf61b181b..f365ebd35a 100644 --- a/server/polar/customer_portal/endpoints/__init__.py +++ b/server/polar/customer_portal/endpoints/__init__.py @@ -1,6 +1,7 @@ from polar.routing import APIRouter from .benefit_grant import router as benefit_grant_router +from .customer import router as customer_router from .downloadables import router as downloadables_router from .license_keys import router as license_keys_router from .oauth_accounts import router as oauth_accounts_router @@ -10,6 +11,7 @@ router = APIRouter(prefix="/customer-portal", tags=["customer_portal"]) router.include_router(benefit_grant_router) +router.include_router(customer_router) router.include_router(downloadables_router) router.include_router(license_keys_router) router.include_router(oauth_accounts_router) diff --git a/server/polar/customer_portal/endpoints/benefit_grant.py b/server/polar/customer_portal/endpoints/benefit_grant.py index 60a512395d..5e1300f76b 100644 --- a/server/polar/customer_portal/endpoints/benefit_grant.py +++ b/server/polar/customer_portal/endpoints/benefit_grant.py @@ -115,7 +115,7 @@ async def get( return benefit_grant -@router.get( +@router.patch( "/{id}", summary="Update Benefit Grant", response_model=CustomerBenefitGrant, diff --git a/server/polar/customer_portal/endpoints/customer.py b/server/polar/customer_portal/endpoints/customer.py new file mode 100644 index 0000000000..7c4a0744ab --- /dev/null +++ b/server/polar/customer_portal/endpoints/customer.py @@ -0,0 +1,49 @@ +from typing import Annotated + +from fastapi import Depends, Path +from pydantic import UUID4 + +from polar.auth.models import is_customer, is_user +from polar.customer.service import customer as customer_service +from polar.exceptions import ResourceNotFound +from polar.models import Customer +from polar.openapi import APITag +from polar.postgres import AsyncSession, get_db_session +from polar.routing import APIRouter + +from .. import auth +from ..schemas.customer import CustomerPortalCustomer + +router = APIRouter(prefix="/customers", tags=["customers", APITag.documented]) + +CustomerID = Annotated[UUID4, Path(description="The customer ID.")] +CustomerNotFound = { + "description": "Customer not found.", + "model": ResourceNotFound.schema(), +} + + +@router.get( + "/{id}", + summary="Get Customer", + response_model=CustomerPortalCustomer, + responses={404: CustomerNotFound}, +) +async def get( + id: CustomerID, + auth_subject: auth.CustomerPortalRead, + session: AsyncSession = Depends(get_db_session), +) -> Customer: + """Get a customer by ID for the authenticated customer or user.""" + customer: Customer | None = None + if is_user(auth_subject): + customer = await customer_service.get_by_id_and_user( + session, id, auth_subject.subject + ) + elif is_customer(auth_subject) and auth_subject.subject.id == id: + customer = auth_subject.subject + + if customer is None: + raise ResourceNotFound() + + return customer diff --git a/server/polar/customer_portal/endpoints/oauth_accounts.py b/server/polar/customer_portal/endpoints/oauth_accounts.py index cb5ca932dc..2d18b20dea 100644 --- a/server/polar/customer_portal/endpoints/oauth_accounts.py +++ b/server/polar/customer_portal/endpoints/oauth_accounts.py @@ -1,3 +1,4 @@ +import uuid from typing import Any import structlog @@ -7,8 +8,18 @@ from httpx_oauth.clients.github import GitHubOAuth2 from httpx_oauth.exceptions import GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError - +from pydantic import UUID4 + +from polar.auth.models import ( + Customer, + is_anonymous, + is_customer, + is_user, +) from polar.config import settings +from polar.customer.service import customer as customer_service +from polar.customer_session.service import customer_session as customer_session_service +from polar.exceptions import PolarError from polar.integrations.github.client import Forbidden from polar.kit import jwt from polar.kit.http import ReturnTo, add_query_parameters, get_safe_return_url @@ -19,6 +30,7 @@ from polar.routing import APIRouter from .. import auth +from ..schemas.oauth_accounts import AuthorizeResponse router = APIRouter(prefix="/oauth-accounts", tags=["oauth-accounts", APITag.private]) @@ -37,15 +49,33 @@ } +class OAuthCallbackError(PolarError): + def __init__(self, message: str) -> None: + super().__init__(message, 400) + + @router.get("/authorize", name="customer_portal.oauth_accounts.authorize") async def authorize( request: Request, return_to: ReturnTo, - auth_subject: auth.CustomerPortalOAuthAccount, + auth_subject: auth.CustomerPortalWrite, platform: CustomerOAuthPlatform = Query(...), -) -> RedirectResponse: + customer_id: UUID4 = Query(...), + session: AsyncSession = Depends(get_db_session), +) -> AuthorizeResponse: + customer: Customer | None = None + if is_user(auth_subject): + customer = await customer_service.get_by_id_and_user( + session, customer_id, auth_subject.subject + ) + elif is_customer(auth_subject) and auth_subject.subject.id == customer_id: + customer = auth_subject.subject + + if customer is None: + raise Forbidden("Invalid customer") + state = { - "customer_id": str(auth_subject.subject.id), + "customer_id": str(customer.id), "platform": platform, "return_to": return_to, } @@ -57,7 +87,8 @@ async def authorize( redirect_uri=str(request.url_for("customer_portal.oauth_accounts.callback")), state=encoded_state, ) - return RedirectResponse(authorization_url, 303) + + return AuthorizeResponse(url=authorization_url) @router.get("/callback", name="customer_portal.oauth_accounts.callback") @@ -78,15 +109,33 @@ async def callback( except jwt.DecodeError as e: raise Forbidden("Invalid state") from e - if str(auth_subject.subject.id) != state_data["customer_id"]: - raise Forbidden("Invalid state") + customer_id = uuid.UUID(state_data.get("customer_id")) + customer: Customer | None = None + if is_user(auth_subject): + customer = await customer_service.get_by_id_and_user( + session, customer_id, auth_subject.subject + ) + elif is_anonymous(auth_subject): + # Trust the customer ID in the state for anonymous users + customer = await customer_service.get(session, customer_id) + + if customer is None: + raise Forbidden("Invalid customer") return_to = state_data["return_to"] platform = CustomerOAuthPlatform(state_data["platform"]) + redirect_url = get_safe_return_url(return_to) + # If the user is not authenticated, create a new customer session, we trust the customer ID in the state + if is_anonymous(auth_subject): + token, _ = await customer_session_service.create_customer_session( + session, customer + ) + redirect_url = add_query_parameters(redirect_url, customer_session_token=token) + if code is None or error is not None: - redirect_url = get_safe_return_url( - add_query_parameters(return_to, error=error or "Failed to authorize.") + redirect_url = add_query_parameters( + redirect_url, error=error or "Failed to authorize." ) return RedirectResponse(redirect_url, 303) @@ -96,10 +145,8 @@ async def callback( code, str(request.url_for("customer_portal.oauth_accounts.callback")) ) except GetAccessTokenError as e: - redirect_url = get_safe_return_url( - add_query_parameters( - return_to, error="Failed to get access token. Please try again later." - ) + redirect_url = add_query_parameters( + redirect_url, error="Failed to get access token. Please try again later." ) log.error("Failed to get access token", error=str(e)) return RedirectResponse(redirect_url, 303) @@ -107,11 +154,9 @@ async def callback( try: profile = await client.get_profile(oauth2_token_data["access_token"]) except GetProfileError as e: - redirect_url = get_safe_return_url( - add_query_parameters( - return_to, - error="Failed to get profile information. Please try again later.", - ) + redirect_url = add_query_parameters( + redirect_url, + error="Failed to get profile information. Please try again later.", ) log.error("Failed to get account ID", error=str(e)) return RedirectResponse(redirect_url, 303) @@ -124,8 +169,7 @@ async def callback( account_username=platform.get_account_username(profile), ) - customer = auth_subject.subject customer.set_oauth_account(oauth_account, platform) session.add(customer) - return RedirectResponse(state_data["return_to"]) + return RedirectResponse(redirect_url) diff --git a/server/polar/customer_portal/schemas/customer.py b/server/polar/customer_portal/schemas/customer.py new file mode 100644 index 0000000000..882afe3075 --- /dev/null +++ b/server/polar/customer_portal/schemas/customer.py @@ -0,0 +1,17 @@ +from polar.kit.address import Address +from polar.kit.schemas import IDSchema, Schema, TimestampedSchema +from polar.kit.tax import TaxID + + +class CustomerPortalOAuthAccount(Schema): + account_id: str + account_username: str | None + + +class CustomerPortalCustomer(IDSchema, TimestampedSchema): + email: str + email_verified: bool + name: str | None + billing_address: Address | None + tax_id: TaxID | None + oauth_accounts: dict[str, CustomerPortalOAuthAccount] diff --git a/server/polar/customer_portal/schemas/oauth_accounts.py b/server/polar/customer_portal/schemas/oauth_accounts.py new file mode 100644 index 0000000000..034a97d4bb --- /dev/null +++ b/server/polar/customer_portal/schemas/oauth_accounts.py @@ -0,0 +1,5 @@ +from polar.kit.schemas import Schema + + +class AuthorizeResponse(Schema): + url: str diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index 116f894be7..8499711a59 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -34,9 +34,9 @@ def get_account_key(self, account_id: str) -> str: def get_account_id(self, data: dict[str, Any]) -> str: if self == CustomerOAuthPlatform.github: - return data["id"] + return str(data["id"]) if self == CustomerOAuthPlatform.discord: - return data["id"] + return str(data["id"]) raise NotImplementedError() def get_account_username(self, data: dict[str, Any]) -> str: @@ -131,13 +131,22 @@ def set_oauth_account( self, oauth_account: CustomerOAuthAccount, platform: CustomerOAuthPlatform ) -> None: account_key = platform.get_account_key(oauth_account.account_id) - self._oauth_accounts[account_key] = dataclasses.asdict(oauth_account) + self._oauth_accounts = { + **self._oauth_accounts, + account_key: dataclasses.asdict(oauth_account), + } def remove_oauth_account( self, account_id: str, platform: CustomerOAuthPlatform ) -> None: account_key = platform.get_account_key(account_id) - self._oauth_accounts.pop(account_key, None) + self._oauth_accounts = { + k: v for k, v in self._oauth_accounts.items() if k != account_key + } + + @property + def oauth_accounts(self) -> dict[str, Any]: + return self._oauth_accounts @property def legacy_user_id(self) -> UUID: From f99d3820d71b9944c05fbe02542473bfe490f7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 12 Dec 2024 11:42:38 +0100 Subject: [PATCH 33/47] server/order: fix error when handling one time purchases invoice --- server/polar/order/service.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index c1872a0027..d2be3f896c 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -407,6 +407,9 @@ async def create_order_from_stripe( created_at=datetime.fromtimestamp(invoice.created, tz=UTC), ) + organization = await organization_service.get(session, product.organization_id) + assert organization is not None + # Get or create customer assert invoice.customer is not None if customer is None: @@ -414,7 +417,7 @@ async def create_order_from_stripe( get_expandable_id(invoice.customer) ) customer = await customer_service.get_or_create_from_stripe_customer( - session, stripe_customer, product.organization + session, stripe_customer, organization ) order.customer = customer @@ -456,10 +459,6 @@ async def create_order_from_stripe( order_id=order.id, ) - organization = await organization_service.get( - session, product.organization_id - ) - assert organization is not None await self.send_admin_notification(session, organization, order) await self.send_confirmation_email(session, organization, order) From 0b8359e66e013f58c08899d605f353e6589b3a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 12 Dec 2024 14:09:36 +0100 Subject: [PATCH 34/47] server/checkout: handle authenticated user properly during checkout --- server/polar/checkout/endpoints.py | 7 +- server/polar/checkout/service.py | 45 ++++++- server/polar/customer/service.py | 22 +++ server/tests/checkout/test_service.py | 155 +++++++++++++++++++++- server/tests/fixtures/random_objects.py | 39 +++++- server/tests/order/test_service.py | 7 +- server/tests/subscription/test_service.py | 5 +- 7 files changed, 263 insertions(+), 17 deletions(-) diff --git a/server/polar/checkout/endpoints.py b/server/polar/checkout/endpoints.py index e95fd72774..261f61fbf0 100644 --- a/server/polar/checkout/endpoints.py +++ b/server/polar/checkout/endpoints.py @@ -181,7 +181,7 @@ async def client_create( """Create a checkout session from a client. Suitable to build checkout links.""" ip_address = request.client.host if request.client else None return await checkout_service.client_create( - session, checkout_create, ip_geolocation_client, ip_address + session, checkout_create, auth_subject, ip_geolocation_client, ip_address ) @@ -223,6 +223,7 @@ async def client_update( async def client_confirm( client_secret: CheckoutClientSecret, checkout_confirm: CheckoutConfirm, + auth_subject: auth.CheckoutWeb, session: AsyncSession = Depends(get_db_session), locker: Locker = Depends(get_locker), ) -> Checkout: @@ -236,7 +237,9 @@ async def client_confirm( if checkout is None: raise ResourceNotFound() - return await checkout_service.confirm(session, locker, checkout, checkout_confirm) + return await checkout_service.confirm( + session, locker, auth_subject, checkout, checkout_confirm + ) @router.get("/client/{client_secret}/stream", include_in_schema=False) diff --git a/server/polar/checkout/service.py b/server/polar/checkout/service.py index eef172c7e7..af82117d33 100644 --- a/server/polar/checkout/service.py +++ b/server/polar/checkout/service.py @@ -8,7 +8,9 @@ from sqlalchemy.orm import contains_eager, joinedload, selectinload from polar.auth.models import ( + Anonymous, AuthSubject, + is_direct_user, is_organization, is_user, ) @@ -406,6 +408,7 @@ async def client_create( self, session: AsyncSession, checkout_create: CheckoutCreatePublic, + auth_subject: AuthSubject[User | Anonymous], ip_geolocation_client: ip_geolocation.IPGeolocationClient | None = None, ip_address: str | None = None, ) -> Checkout: @@ -487,7 +490,27 @@ async def client_create( customer=None, subscription=None, ) - if checkout_create.customer_email is not None: + if is_direct_user(auth_subject): + customer = await customer_service.get_by_user_and_organization( + session, auth_subject.subject, product.organization + ) + if customer is not None: + checkout.customer = customer + checkout.customer_email = customer.email + if checkout_create.subscription_id is not None: + ( + subscription, + subscription_customer, + ) = await self._get_validated_subscription( + session, + checkout_create.subscription_id, + product.organization_id, + ) + if subscription_customer == customer: + checkout.subscription = subscription + else: + checkout.customer_email = auth_subject.subject.email + elif checkout_create.customer_email is not None: checkout.customer_email = checkout_create.customer_email if checkout.payment_processor == PaymentProcessor.stripe: @@ -640,6 +663,7 @@ async def confirm( self, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[User | Anonymous], checkout: Checkout, checkout_confirm: CheckoutConfirm, ) -> Checkout: @@ -653,7 +677,7 @@ async def confirm( ) as discount_redemption: discount_redemption.checkout = checkout return await self._confirm_inner( - session, checkout, checkout_confirm + session, auth_subject, checkout, checkout_confirm ) except DiscountNotRedeemableError as e: raise PolarRequestValidationError( @@ -667,11 +691,14 @@ async def confirm( ] ) from e - return await self._confirm_inner(session, checkout, checkout_confirm) + return await self._confirm_inner( + session, auth_subject, checkout, checkout_confirm + ) async def _confirm_inner( self, session: AsyncSession, + auth_subject: AuthSubject[User | Anonymous], checkout: Checkout, checkout_confirm: CheckoutConfirm, ) -> Checkout: @@ -739,7 +766,9 @@ async def _confirm_inner( raise PolarRequestValidationError(errors) if checkout.payment_processor == PaymentProcessor.stripe: - customer = await self._create_or_update_customer(session, checkout) + customer = await self._create_or_update_customer( + session, auth_subject, checkout + ) checkout.customer = customer stripe_customer_id = customer.stripe_customer_id assert stripe_customer_id is not None @@ -1550,7 +1579,10 @@ def _get_required_confirm_fields(self, checkout: Checkout) -> set[str]: return fields async def _create_or_update_customer( - self, session: AsyncSession, checkout: Checkout + self, + session: AsyncSession, + auth_subject: AuthSubject[User | Anonymous], + checkout: Checkout, ) -> Customer: customer = checkout.customer if customer is None: @@ -1569,6 +1601,9 @@ async def _create_or_update_customer( organization=checkout.organization, ) + if is_direct_user(auth_subject): + await customer_service.link_user(session, customer, auth_subject.subject) + stripe_customer_id = customer.stripe_customer_id if stripe_customer_id is None: create_params: stripe_lib.Customer.CreateParams = {"email": customer.email} diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 5a1f0a6156..f7ef7d3cc6 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -177,6 +177,21 @@ async def get_by_id_and_user( result = await session.execute(statement) return result.scalar_one_or_none() + async def get_by_user_and_organization( + self, session: AsyncSession, user: User, organization: Organization + ) -> Customer | None: + statement = ( + select(Customer) + .join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id) + .where( + Customer.deleted_at.is_(None), + UserCustomer.user_id == user.id, + Customer.organization_id == organization.id, + ) + ) + result = await session.execute(statement) + return result.scalar_one_or_none() + async def get_by_stripe_customer_id( self, session: AsyncSession, stripe_customer_id: str ) -> Customer | None: @@ -221,6 +236,13 @@ async def get_or_create_from_stripe_customer( session.add(customer) return customer + async def link_user( + self, session: AsyncSession, customer: Customer, user: User + ) -> UserCustomer: + user_customer = UserCustomer(user=user, customer=customer) + session.add(user_customer) + return user_customer + def _get_readable_customer_statement( self, auth_subject: AuthSubject[User | Organization] ) -> Select[tuple[Customer]]: diff --git a/server/tests/checkout/test_service.py b/server/tests/checkout/test_service.py index 40271c3c1f..ed5124f875 100644 --- a/server/tests/checkout/test_service.py +++ b/server/tests/checkout/test_service.py @@ -11,7 +11,7 @@ from pytest_mock import MockerFixture from sqlalchemy.orm import joinedload -from polar.auth.models import Anonymous, AuthSubject +from polar.auth.models import Anonymous, AuthMethod, AuthSubject from polar.checkout.schemas import ( CheckoutConfirmStripe, CheckoutCreatePublic, @@ -31,6 +31,7 @@ PaymentRequired, ) from polar.checkout.service import checkout as checkout_service +from polar.customer.service import customer as customer_service from polar.customer_session.service import customer_session as customer_session_service from polar.discount.service import discount as discount_service from polar.enums import PaymentProcessor @@ -69,6 +70,7 @@ create_product, create_product_price_fixed, create_subscription, + create_user_customer, ) @@ -1048,6 +1050,7 @@ async def test_not_existing_price( CheckoutCreatePublic( product_price_id=uuid.uuid4(), ), + auth_subject, ) async def test_archived_price( @@ -1065,7 +1068,7 @@ async def test_archived_price( ) with pytest.raises(PolarRequestValidationError): await checkout_service.client_create( - session, CheckoutCreatePublic(product_price_id=price.id) + session, CheckoutCreatePublic(product_price_id=price.id), auth_subject ) async def test_archived_product( @@ -1083,6 +1086,7 @@ async def test_archived_product( CheckoutCreatePublic( product_price_id=product_one_time.prices[0].id, ), + auth_subject, ) async def test_valid_fixed_price( @@ -1094,7 +1098,7 @@ async def test_valid_fixed_price( price = product_one_time.prices[0] assert isinstance(price, ProductPriceFixed) checkout = await checkout_service.client_create( - session, CheckoutCreatePublic(product_price_id=price.id) + session, CheckoutCreatePublic(product_price_id=price.id), auth_subject ) assert checkout.product_price == price @@ -1111,7 +1115,7 @@ async def test_valid_free_price( price = product_one_time_free_price.prices[0] assert isinstance(price, ProductPriceFree) checkout = await checkout_service.client_create( - session, CheckoutCreatePublic(product_price_id=price.id) + session, CheckoutCreatePublic(product_price_id=price.id), auth_subject ) assert checkout.product_price == price @@ -1130,7 +1134,7 @@ async def test_valid_custom_price( price.preset_amount = 4242 checkout = await checkout_service.client_create( - session, CheckoutCreatePublic(product_price_id=price.id) + session, CheckoutCreatePublic(product_price_id=price.id), auth_subject ) assert checkout.product_price == price @@ -1151,6 +1155,7 @@ async def test_valid_from_legacy_checkout_link( CheckoutCreatePublic( product_price_id=price.id, from_legacy_checkout_link=True ), + auth_subject, ) assert checkout.product_price == price @@ -1158,6 +1163,90 @@ async def test_valid_from_legacy_checkout_link( assert checkout.amount == price.price_amount assert checkout.currency == price.price_currency + @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) + async def test_user_without_customer( + self, + session: AsyncSession, + auth_subject: AuthSubject[User], + product_one_time: Product, + ) -> None: + price = product_one_time.prices[0] + assert isinstance(price, ProductPriceFixed) + checkout = await checkout_service.client_create( + session, + CheckoutCreatePublic( + product_price_id=price.id, from_legacy_checkout_link=True + ), + auth_subject, + ) + + assert checkout.customer is None + assert checkout.customer_email == auth_subject.subject.email + + @pytest.mark.auth( + AuthSubjectFixture( + subject="user_second", method=AuthMethod.PERSONAL_ACCESS_TOKEN + ) + ) + async def test_indirect_user_with_customer( + self, + save_fixture: SaveFixture, + session: AsyncSession, + auth_subject: AuthSubject[User], + product_one_time: Product, + organization: Organization, + ) -> None: + await create_user_customer( + save_fixture, user=auth_subject.subject, organization=organization + ) + + price = product_one_time.prices[0] + assert isinstance(price, ProductPriceFixed) + checkout = await checkout_service.client_create( + session, + CheckoutCreatePublic( + product_price_id=price.id, from_legacy_checkout_link=True + ), + auth_subject, + ) + + assert checkout.customer is None + assert checkout.customer_email is None + + @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) + async def test_user_with_customer( + self, + save_fixture: SaveFixture, + stripe_service_mock: MagicMock, + session: AsyncSession, + auth_subject: AuthSubject[User], + product_one_time: Product, + organization: Organization, + ) -> None: + customer = await create_user_customer( + save_fixture, user=auth_subject.subject, organization=organization + ) + + stripe_service_mock.create_customer_session.return_value = SimpleNamespace( + client_secret="STRIPE_CUSTOMER_SESSION_CLIENT_SECRET" + ) + + price = product_one_time.prices[0] + assert isinstance(price, ProductPriceFixed) + checkout = await checkout_service.client_create( + session, + CheckoutCreatePublic( + product_price_id=price.id, from_legacy_checkout_link=True + ), + auth_subject, + ) + + assert checkout.customer == customer + assert ( + checkout.payment_processor_metadata["customer_session_client_secret"] + == "STRIPE_CUSTOMER_SESSION_CLIENT_SECRET" + ) + @pytest.mark.asyncio @pytest.mark.skip_db_asserts @@ -1731,12 +1820,14 @@ async def test_missing_amount_on_custom_price( self, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_custom: Checkout, ) -> None: with pytest.raises(PolarRequestValidationError): await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_custom, CheckoutConfirmStripe.model_validate( { @@ -1761,12 +1852,14 @@ async def test_missing_required_field( payload: dict[str, str], session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, ) -> None: with pytest.raises(PolarRequestValidationError): await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate(payload), ) @@ -1775,12 +1868,14 @@ async def test_not_open( self, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_confirmed_one_time: Checkout, ) -> None: with pytest.raises(NotOpenCheckout): await checkout_service.confirm( session, locker, + auth_subject, checkout_confirmed_one_time, CheckoutConfirmStripe.model_validate( {"confirmation_token_id": "CONFIRMATION_TOKEN_ID"} @@ -1792,6 +1887,7 @@ async def test_archived_price( save_fixture: SaveFixture, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, ) -> None: archived_price = await create_product_price_fixed( @@ -1804,6 +1900,7 @@ async def test_archived_price( await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate( { @@ -1820,6 +1917,7 @@ async def test_calculate_tax_error( calculate_tax_mock: AsyncMock, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, ) -> None: calculate_tax_mock.side_effect = IncompleteTaxLocation( @@ -1830,6 +1928,7 @@ async def test_calculate_tax_error( await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate( { @@ -1858,6 +1957,7 @@ async def test_valid_stripe( stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, ) -> None: stripe_service_mock.create_customer.return_value = SimpleNamespace( @@ -1869,6 +1969,7 @@ async def test_valid_stripe( checkout = await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate( { @@ -1920,6 +2021,7 @@ async def test_valid_fully_discounted_subscription( stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_discount_percentage_100: Checkout, discount_percentage_100: Discount, ) -> None: @@ -1932,6 +2034,7 @@ async def test_valid_fully_discounted_subscription( checkout = await checkout_service.confirm( session, locker, + auth_subject, checkout_discount_percentage_100, CheckoutConfirmStripe.model_validate( { @@ -1974,6 +2077,7 @@ async def test_valid_stripe_free( mocker: MockerFixture, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_free: Checkout, ) -> None: enqueue_job_mock = mocker.patch("polar.checkout.service.enqueue_job") @@ -1985,6 +2089,7 @@ async def test_valid_stripe_free( checkout = await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_free, CheckoutConfirmStripe.model_validate( { @@ -2012,6 +2117,7 @@ async def test_valid_stripe_existing_customer( stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], organization: Organization, checkout_one_time_fixed: Checkout, ) -> None: @@ -2031,6 +2137,7 @@ async def test_valid_stripe_existing_customer( checkout = await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate( { @@ -2049,6 +2156,7 @@ async def test_valid_stripe_existing_customer_email( stripe_service_mock: MagicMock, session: AsyncSession, locker: Locker, + auth_subject: AuthSubject[Anonymous], checkout_one_time_fixed: Checkout, customer: Customer, ) -> None: @@ -2059,6 +2167,7 @@ async def test_valid_stripe_existing_customer_email( checkout = await checkout_service.confirm( session, locker, + auth_subject, checkout_one_time_fixed, CheckoutConfirmStripe.model_validate( { @@ -2074,6 +2183,42 @@ async def test_valid_stripe_existing_customer_email( assert checkout.customer == customer stripe_service_mock.update_customer.assert_called_once() + @pytest.mark.auth(AuthSubjectFixture(subject="user_second")) + async def test_link_customer_to_authenticated_user( + self, + stripe_service_mock: MagicMock, + session: AsyncSession, + locker: Locker, + auth_subject: AuthSubject[User], + checkout_one_time_fixed: Checkout, + ) -> None: + stripe_service_mock.create_customer.return_value = SimpleNamespace( + id="STRIPE_CUSTOMER_ID" + ) + stripe_service_mock.create_payment_intent.return_value = SimpleNamespace( + client_secret="CLIENT_SECRET", status="succeeded" + ) + checkout = await checkout_service.confirm( + session, + locker, + auth_subject, + checkout_one_time_fixed, + CheckoutConfirmStripe.model_validate( + { + "confirmation_token_id": "CONFIRMATION_TOKEN_ID", + "customer_name": "Customer Name", + "customer_email": "customer@example.com", + "customer_billing_address": {"country": "FR"}, + } + ), + ) + + assert checkout.customer is not None + linked_customer = await customer_service.get_by_id_and_user( + session, checkout.customer.id, auth_subject.subject + ) + assert linked_customer is not None + def build_stripe_payment_intent( *, diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 135c123116..bd62c5e771 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -39,6 +39,7 @@ Repository, Subscription, User, + UserCustomer, UserOrganization, ) from polar.models.benefit import BenefitType @@ -79,6 +80,10 @@ def rstr(prefix: str) -> str: return prefix + "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) +def lstr(suffix: str) -> str: + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + suffix + + async def create_organization( save_fixture: SaveFixture, name_prefix: str = "testorg", **kwargs: Any ) -> Organization: @@ -1271,7 +1276,12 @@ async def customer( save_fixture: SaveFixture, organization: Organization, ) -> Customer: - return await create_customer(save_fixture, organization=organization) + return await create_customer( + save_fixture, + organization=organization, + email=lstr("customer@example.com"), + stripe_customer_id=lstr("STRIPE_CUSTOMER_ID"), + ) @pytest_asyncio.fixture @@ -1282,11 +1292,34 @@ async def customer_second( return await create_customer( save_fixture, organization=organization, - email="customer.second@example.com", - stripe_customer_id="STRIPE_CUSTOMER_ID_2", + email=lstr("customer.second@example.com"), + stripe_customer_id=lstr("STRIPE_CUSTOMER_ID_2"), ) +async def create_user_customer( + save_fixture: SaveFixture, + *, + user: User, + organization: Organization, + email: str = "user.customer@example.com", + email_verified: bool = False, + name: str = "Customer", + stripe_customer_id: str = "STRIPE_USER_CUSTOMER_ID", +) -> Customer: + customer = await create_customer( + save_fixture, + organization=organization, + email=email, + email_verified=email_verified, + name=name, + stripe_customer_id=stripe_customer_id, + ) + user_customer = UserCustomer(user=user, customer=customer) + await save_fixture(user_customer) + return customer + + @pytest_asyncio.fixture async def subscription( save_fixture: SaveFixture, diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index a5199d067d..1d50359709 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -1,7 +1,7 @@ import time from datetime import datetime from types import SimpleNamespace -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -329,6 +329,7 @@ async def test_subscription_no_account( subscription_id=subscription.stripe_subscription_id, lines=[(product.prices[0].stripe_price_id, False, None)], created=created_unix_timestamp, + customer_id=cast(str, subscription.customer.stripe_customer_id), ) payment_transaction = await create_transaction( @@ -383,6 +384,7 @@ async def test_subscription_proration( (product.prices[0].stripe_price_id, False, None), ], created=created_unix_timestamp, + customer_id=cast(str, subscription.customer.stripe_customer_id), ) payment_transaction = await create_transaction( @@ -422,6 +424,7 @@ async def test_subscription_only_proration( "metadata": {"product_price_id": str(product.prices[0].id)} }, created=created_unix_timestamp, + customer_id=cast(str, subscription.customer.stripe_customer_id), ) payment_transaction = await create_transaction( @@ -458,6 +461,7 @@ async def test_subscription_with_account( subscription_id=subscription.stripe_subscription_id, lines=[(product.prices[0].stripe_price_id, False, None)], created=created_unix_timestamp, + customer_id=cast(str, subscription.customer.stripe_customer_id), ) invoice_total = invoice.total - (invoice.tax or 0) @@ -542,6 +546,7 @@ async def test_subscription_applied_balance( "metadata": {"product_price_id": str(product.prices[0].id)} }, created=created_unix_timestamp, + customer_id=cast(str, subscription.customer.stripe_customer_id), ) order = await order_service.create_order_from_stripe(session, invoice=invoice) diff --git a/server/tests/subscription/test_service.py b/server/tests/subscription/test_service.py index 3417691ac5..3158ac557c 100644 --- a/server/tests/subscription/test_service.py +++ b/server/tests/subscription/test_service.py @@ -1,4 +1,5 @@ from datetime import UTC, datetime, timedelta +from typing import cast from unittest.mock import MagicMock, call import pytest @@ -284,7 +285,9 @@ async def test_existing_customer( product: Product, customer: Customer, ) -> None: - stripe_customer = construct_stripe_customer() + stripe_customer = construct_stripe_customer( + id=cast(str, customer.stripe_customer_id) + ) get_customer_mock = stripe_service_mock.get_customer get_customer_mock.return_value = stripe_customer From f3702b164bf2ccd06e8cf9c8012485a8a9eee57c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 12 Dec 2024 15:18:19 +0100 Subject: [PATCH 35/47] server/customer_portal: add a SSE endpoint for customer --- server/polar/customer/service.py | 14 ++++++++++ .../customer_portal/endpoints/customer.py | 26 ++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index f7ef7d3cc6..f30bddc8b0 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -192,6 +192,20 @@ async def get_by_user_and_organization( result = await session.execute(statement) return result.scalar_one_or_none() + async def get_by_user( + self, session: AsyncSession, user: User + ) -> Sequence[Customer]: + statement = ( + select(Customer) + .join(UserCustomer, onclause=UserCustomer.customer_id == Customer.id) + .where( + Customer.deleted_at.is_(None), + UserCustomer.user_id == user.id, + ) + ) + result = await session.execute(statement) + return result.unique().scalars().all() + async def get_by_stripe_customer_id( self, session: AsyncSession, stripe_customer_id: str ) -> Customer | None: diff --git a/server/polar/customer_portal/endpoints/customer.py b/server/polar/customer_portal/endpoints/customer.py index 7c4a0744ab..11f548017c 100644 --- a/server/polar/customer_portal/endpoints/customer.py +++ b/server/polar/customer_portal/endpoints/customer.py @@ -1,14 +1,18 @@ from typing import Annotated -from fastapi import Depends, Path +from fastapi import Depends, Path, Request from pydantic import UUID4 +from sse_starlette import EventSourceResponse from polar.auth.models import is_customer, is_user from polar.customer.service import customer as customer_service +from polar.eventstream.endpoints import subscribe +from polar.eventstream.service import Receivers from polar.exceptions import ResourceNotFound from polar.models import Customer from polar.openapi import APITag from polar.postgres import AsyncSession, get_db_session +from polar.redis import Redis, get_redis from polar.routing import APIRouter from .. import auth @@ -23,6 +27,26 @@ } +@router.get("/stream", include_in_schema=False) +async def stream( + request: Request, + auth_subject: auth.CustomerPortalRead, + session: AsyncSession = Depends(get_db_session), + redis: Redis = Depends(get_redis), +) -> EventSourceResponse: + if is_user(auth_subject): + customers = await customer_service.get_by_user(session, auth_subject.subject) + elif is_customer(auth_subject): + customers = [auth_subject.subject] + + channels: list[str] = [] + for customer in customers: + receivers = Receivers(customer_id=customer.id) + channels = [*channels, *receivers.get_channels()] + + return EventSourceResponse(subscribe(redis, channels, request)) + + @router.get( "/{id}", summary="Get Customer", From 527e04253bf03198c3d7ab906a049a4ed4ad2e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 13 Dec 2024 09:59:09 +0100 Subject: [PATCH 36/47] server/customer_portal: add an OTP code mechanism to generate customer session from the web --- ...2024-12-12-1627_add_customersessioncode.py | 87 +++++++++++ server/polar/config.py | 2 + .../customer_session_code.html | 24 +++ .../customer_portal/endpoints/__init__.py | 4 + .../endpoints/customer_session.py | 53 +++++++ .../customer_portal/endpoints/organization.py | 41 +++++ .../schemas/customer_session.py | 16 ++ .../service/customer_session.py | 140 ++++++++++++++++++ .../customer_portal/service/organization.py | 21 +++ server/polar/email/email_templates/base.html | 23 ++- server/polar/models/__init__.py | 2 + server/polar/models/customer_session_code.py | 31 ++++ server/polar/order/service.py | 4 +- server/polar/subscription/service.py | 5 +- .../tests/magic_link/testdata/magic_link.html | 23 ++- .../testdata/magic_link_return_to.html | 23 ++- ...ainerCreateAccountNotificationPayload.html | 23 ++- ...tainerNewPaidSubscriptionNotification.html | 23 ++- ..._MaintainerNewProductSaleNotification.html | 23 ++- ...rmationPendingdNotification_no_stripe.html | 23 ++- ...ationPendingdNotification_with_stripe.html | 23 ++- ...erPledgeCreatedNotification_anonymous.html | 23 ++- ...erPledgeCreatedNotification_no_stripe.html | 23 ++- ...CreatedNotification_pay_on_completion.html | 23 ++- ...PledgeCreatedNotification_with_stripe.html | 23 ++- ...test_MaintainerPledgePaidNotification.html | 23 ++- ...rPledgePendingdNotification_no_stripe.html | 23 ++- ...ledgePendingdNotification_with_stripe.html | 23 ++- ...dIssueConfirmationPendingNotification.html | 23 ++- ...ationPendingNotification_with_account.html | 23 ++- ...tainerPledgedIssuePendingNotification.html | 23 ++- ...IssuePendingNotification_with_account.html | 23 ++- ...test_PledgerPledgePendingNotification.html | 23 ++- ...PendingNotification_pay_on_completion.html | 23 ++- .../testdata/test_RewardPaidNotification.html | 23 ++- ...st_TeamAdminMemberPledgedNotification.html | 23 ++- 36 files changed, 932 insertions(+), 27 deletions(-) create mode 100644 server/migrations/versions/2024-12-12-1627_add_customersessioncode.py create mode 100644 server/polar/customer_portal/email_templates/customer_session_code.html create mode 100644 server/polar/customer_portal/endpoints/customer_session.py create mode 100644 server/polar/customer_portal/endpoints/organization.py create mode 100644 server/polar/customer_portal/schemas/customer_session.py create mode 100644 server/polar/customer_portal/service/customer_session.py create mode 100644 server/polar/customer_portal/service/organization.py create mode 100644 server/polar/models/customer_session_code.py diff --git a/server/migrations/versions/2024-12-12-1627_add_customersessioncode.py b/server/migrations/versions/2024-12-12-1627_add_customersessioncode.py new file mode 100644 index 0000000000..bee690677d --- /dev/null +++ b/server/migrations/versions/2024-12-12-1627_add_customersessioncode.py @@ -0,0 +1,87 @@ +"""Add CustomerSessionCode + +Revision ID: 648a1268ab97 +Revises: e47b6d16d3e0 +Create Date: 2024-12-12 16:27:41.581614 + +""" + +import sqlalchemy as sa +from alembic import op + +# Polar Custom Imports + +# revision identifiers, used by Alembic. +revision = "648a1268ab97" +down_revision = "e47b6d16d3e0" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "customer_session_codes", + sa.Column("code", sa.CHAR(length=64), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("customer_id", sa.Uuid(), nullable=False), + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("modified_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["customer_id"], + ["customers.id"], + name=op.f("customer_session_codes_customer_id_fkey"), + ondelete="cascade", + ), + sa.PrimaryKeyConstraint("id", name=op.f("customer_session_codes_pkey")), + sa.UniqueConstraint("code", name=op.f("customer_session_codes_code_key")), + ) + op.create_index( + op.f("ix_customer_session_codes_created_at"), + "customer_session_codes", + ["created_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_session_codes_deleted_at"), + "customer_session_codes", + ["deleted_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_session_codes_expires_at"), + "customer_session_codes", + ["expires_at"], + unique=False, + ) + op.create_index( + op.f("ix_customer_session_codes_modified_at"), + "customer_session_codes", + ["modified_at"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + op.f("ix_customer_session_codes_modified_at"), + table_name="customer_session_codes", + ) + op.drop_index( + op.f("ix_customer_session_codes_expires_at"), + table_name="customer_session_codes", + ) + op.drop_index( + op.f("ix_customer_session_codes_deleted_at"), + table_name="customer_session_codes", + ) + op.drop_index( + op.f("ix_customer_session_codes_created_at"), + table_name="customer_session_codes", + ) + op.drop_table("customer_session_codes") + # ### end Alembic commands ### diff --git a/server/polar/config.py b/server/polar/config.py index 1f8d70eb45..a053c70810 100644 --- a/server/polar/config.py +++ b/server/polar/config.py @@ -64,6 +64,8 @@ class Settings(BaseSettings): # Customer session CUSTOMER_SESSION_TTL: timedelta = timedelta(hours=1) + CUSTOMER_SESSION_CODE_TTL: timedelta = timedelta(minutes=30) + CUSTOMER_SESSION_CODE_LENGTH: int = 6 # Magic link MAGIC_LINK_TTL_SECONDS: int = 60 * 30 # 30 minutes diff --git a/server/polar/customer_portal/email_templates/customer_session_code.html b/server/polar/customer_portal/email_templates/customer_session_code.html new file mode 100644 index 0000000000..a1fe7e45fd --- /dev/null +++ b/server/polar/customer_portal/email_templates/customer_session_code.html @@ -0,0 +1,24 @@ +{% extends "base.html" %} + +{% block body %} +

Hi,

+

Here is your code to access your {{ featured_organization.name }} purchases. Click the button below to complete the login process. This code + is only valid for the next {{ code_lifetime_minutes }} minutes.

+ + + + + + + + + + + + +{% endblock %} diff --git a/server/polar/customer_portal/endpoints/__init__.py b/server/polar/customer_portal/endpoints/__init__.py index f365ebd35a..8c2fa2d4f1 100644 --- a/server/polar/customer_portal/endpoints/__init__.py +++ b/server/polar/customer_portal/endpoints/__init__.py @@ -2,18 +2,22 @@ from .benefit_grant import router as benefit_grant_router from .customer import router as customer_router +from .customer_session import router as customer_session_router from .downloadables import router as downloadables_router from .license_keys import router as license_keys_router from .oauth_accounts import router as oauth_accounts_router from .order import router as order_router +from .organization import router as organization_router from .subscription import router as subscription_router router = APIRouter(prefix="/customer-portal", tags=["customer_portal"]) router.include_router(benefit_grant_router) router.include_router(customer_router) +router.include_router(customer_session_router) router.include_router(downloadables_router) router.include_router(license_keys_router) router.include_router(oauth_accounts_router) router.include_router(order_router) +router.include_router(organization_router) router.include_router(subscription_router) diff --git a/server/polar/customer_portal/endpoints/customer_session.py b/server/polar/customer_portal/endpoints/customer_session.py new file mode 100644 index 0000000000..263409683a --- /dev/null +++ b/server/polar/customer_portal/endpoints/customer_session.py @@ -0,0 +1,53 @@ +from fastapi import Depends + +from polar.kit.db.postgres import AsyncSession +from polar.openapi import APITag +from polar.postgres import get_db_session +from polar.routing import APIRouter + +from ..schemas.customer_session import ( + CustomerSessionCodeAuthenticateRequest, + CustomerSessionCodeAuthenticateResponse, + CustomerSessionCodeRequest, +) +from ..service.customer_session import CustomerDoesNotExist, OrganizationDoesNotExist +from ..service.customer_session import customer_session as customer_session_service + +router = APIRouter( + prefix="/customer-session", tags=["customer-session", APITag.private] +) + + +@router.post( + "/request", name="customer_portal.customer_session.request", status_code=202 +) +async def request( + customer_session_code_request: CustomerSessionCodeRequest, + session: AsyncSession = Depends(get_db_session), +) -> None: + try: + customer_session_code, code = await customer_session_service.request( + session, + customer_session_code_request.email, + customer_session_code_request.organization_id, + ) + except (CustomerDoesNotExist, OrganizationDoesNotExist): + # We don't want to leak information about whether the customer or organization exists + return + + await customer_session_service.send( + session, + customer_session_code, + code, + ) + + +@router.post("/authenticate", name="customer_portal.customer_session.authenticate") +async def authenticate( + authenticated_request: CustomerSessionCodeAuthenticateRequest, + session: AsyncSession = Depends(get_db_session), +) -> CustomerSessionCodeAuthenticateResponse: + token, _ = await customer_session_service.authenticate( + session, authenticated_request.code + ) + return CustomerSessionCodeAuthenticateResponse(token=token) diff --git a/server/polar/customer_portal/endpoints/organization.py b/server/polar/customer_portal/endpoints/organization.py new file mode 100644 index 0000000000..d4c8f04151 --- /dev/null +++ b/server/polar/customer_portal/endpoints/organization.py @@ -0,0 +1,41 @@ +from typing import Annotated + +from fastapi import Depends, Path + +from polar.exceptions import ResourceNotFound +from polar.models import Organization +from polar.openapi import APITag +from polar.organization.schemas import Organization as OrganizationSchema +from polar.postgres import AsyncSession, get_db_session +from polar.routing import APIRouter + +from ..service.organization import ( + customer_organization as customer_organization_service, +) + +router = APIRouter(prefix="/organizations", tags=["organizations", APITag.documented]) + +OrganizationSlug = Annotated[str, Path(description="The organization slug.")] +OrganizationNotFound = { + "description": "Organization not found.", + "model": ResourceNotFound.schema(), +} + + +@router.get( + "/{slug}", + summary="Get Organization", + response_model=OrganizationSchema, + responses={404: OrganizationNotFound}, +) +async def get( + slug: OrganizationSlug, + session: AsyncSession = Depends(get_db_session), +) -> Organization: + """Get a customer portal's organization by slug.""" + organization = await customer_organization_service.get_by_slug(session, slug) + + if organization is None: + raise ResourceNotFound() + + return organization diff --git a/server/polar/customer_portal/schemas/customer_session.py b/server/polar/customer_portal/schemas/customer_session.py new file mode 100644 index 0000000000..581ad6f8b0 --- /dev/null +++ b/server/polar/customer_portal/schemas/customer_session.py @@ -0,0 +1,16 @@ +from pydantic import UUID4 + +from polar.kit.schemas import EmailStrDNS, Schema + + +class CustomerSessionCodeRequest(Schema): + email: EmailStrDNS + organization_id: UUID4 + + +class CustomerSessionCodeAuthenticateRequest(Schema): + code: str + + +class CustomerSessionCodeAuthenticateResponse(Schema): + token: str diff --git a/server/polar/customer_portal/service/customer_session.py b/server/polar/customer_portal/service/customer_session.py new file mode 100644 index 0000000000..80d92b3d54 --- /dev/null +++ b/server/polar/customer_portal/service/customer_session.py @@ -0,0 +1,140 @@ +import datetime +import secrets +import string +import uuid +from math import ceil + +from sqlalchemy import select + +from polar.config import settings +from polar.customer.service import customer as customer_service +from polar.customer_session.service import customer_session as customer_session_service +from polar.email.renderer import get_email_renderer +from polar.email.sender import get_email_sender +from polar.exceptions import PolarError +from polar.kit.crypto import get_token_hash +from polar.kit.utils import utc_now +from polar.models import CustomerSession, CustomerSessionCode, Organization +from polar.organization.service import organization as organization_service +from polar.postgres import AsyncSession + + +class CustomerSessionError(PolarError): ... + + +class OrganizationDoesNotExist(CustomerSessionError): + def __init__(self, organization_id: uuid.UUID) -> None: + self.organization_id = organization_id + message = f"Organization {organization_id} does not exist." + super().__init__(message) + + +class CustomerDoesNotExist(CustomerSessionError): + def __init__(self, email: str, organization: Organization) -> None: + self.email = email + self.organization = organization + message = f"Customer does not exist for email {email} and organization {organization.id}." + super().__init__(message) + + +class CustomerSessionCodeInvalidOrExpired(CustomerSessionError): + def __init__(self) -> None: + super().__init__( + "This customer session code is invalid or has expired.", status_code=401 + ) + + +class CustomerSessionService: + async def request( + self, session: AsyncSession, email: str, organization_id: uuid.UUID + ) -> tuple[CustomerSessionCode, str]: + organization = await organization_service.get(session, organization_id) + if organization is None: + raise OrganizationDoesNotExist(organization_id) + + customer = await customer_service.get_by_email_and_organization( + session, email, organization + ) + if customer is None: + raise CustomerDoesNotExist(email, organization) + + code, code_hash = self._generate_code_hash() + + customer_session_code = CustomerSessionCode(code=code_hash, customer=customer) + session.add(customer_session_code) + + return customer_session_code, code + + async def send( + self, + session: AsyncSession, + customer_session_code: CustomerSessionCode, + code: str, + ) -> None: + email_renderer = get_email_renderer( + {"customer_portal": "polar.customer_portal"} + ) + email_sender = get_email_sender() + + customer = customer_session_code.customer + organization = await organization_service.get( + session, customer_session_code.customer.organization_id + ) + assert organization is not None + + delta = customer_session_code.expires_at - utc_now() + code_lifetime_minutes = int(ceil(delta.seconds / 60)) + + subject, body = email_renderer.render_from_template( + f"Access your {organization.name} purchases", + "customer_portal/customer_session_code.html", + { + "featured_organization": organization, + "code": code, + "code_lifetime_minutes": code_lifetime_minutes, + "url": settings.generate_frontend_url( + f"/{organization.slug}/portal/authenticate" + ), + "current_year": datetime.datetime.now().year, + }, + ) + + email_sender.send_to_user( + to_email_addr=customer.email, subject=subject, html_content=body + ) + + async def authenticate( + self, session: AsyncSession, code: str + ) -> tuple[str, CustomerSession]: + code_hash = get_token_hash(code, secret=settings.SECRET) + + statement = select(CustomerSessionCode).where( + CustomerSessionCode.expires_at > utc_now(), + CustomerSessionCode.code == code_hash, + ) + result = await session.execute(statement) + customer_session_code = result.scalar_one_or_none() + + if customer_session_code is None: + raise CustomerSessionCodeInvalidOrExpired() + + customer = customer_session_code.customer + customer.email_verified = True + session.add(customer) + + await session.delete(customer_session_code) + + return await customer_session_service.create_customer_session( + session, customer_session_code.customer + ) + + def _generate_code_hash(self) -> tuple[str, str]: + code = "".join( + secrets.choice(string.ascii_uppercase + string.digits) + for _ in range(settings.CUSTOMER_SESSION_CODE_LENGTH) + ) + code_hash = get_token_hash(code, secret=settings.SECRET) + return code, code_hash + + +customer_session = CustomerSessionService() diff --git a/server/polar/customer_portal/service/organization.py b/server/polar/customer_portal/service/organization.py new file mode 100644 index 0000000000..760a1e2ed2 --- /dev/null +++ b/server/polar/customer_portal/service/organization.py @@ -0,0 +1,21 @@ +from sqlalchemy import select + +from polar.kit.services import ResourceServiceReader +from polar.models import Organization +from polar.postgres import AsyncSession + + +class CustomerOrganizationService(ResourceServiceReader[Organization]): + async def get_by_slug( + self, session: AsyncSession, slug: str + ) -> Organization | None: + statement = select(Organization).where( + Organization.deleted_at.is_(None), + Organization.blocked_at.is_(None), + Organization.slug == slug, + ) + result = await session.execute(statement) + return result.unique().scalar_one_or_none() + + +customer_organization = CustomerOrganizationService(Organization) diff --git a/server/polar/email/email_templates/base.html b/server/polar/email/email_templates/base.html index b0a82dc1ae..68e75f73a8 100644 --- a/server/polar/email/email_templates/base.html +++ b/server/polar/email/email_templates/base.html @@ -224,6 +224,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -461,7 +481,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/polar/models/__init__.py b/server/polar/models/__init__.py index afe43110d4..f976ddd046 100644 --- a/server/polar/models/__init__.py +++ b/server/polar/models/__init__.py @@ -9,6 +9,7 @@ from .custom_field import CustomField from .customer import Customer from .customer_session import CustomerSession +from .customer_session_code import CustomerSessionCode from .discount import Discount from .discount_product import DiscountProduct from .discount_redemption import DiscountRedemption @@ -64,6 +65,7 @@ "CheckoutLink", "Customer", "CustomerSession", + "CustomerSessionCode", "CustomField", "Discount", "DiscountProduct", diff --git a/server/polar/models/customer_session_code.py b/server/polar/models/customer_session_code.py new file mode 100644 index 0000000000..abb0c1e7e8 --- /dev/null +++ b/server/polar/models/customer_session_code.py @@ -0,0 +1,31 @@ +from datetime import datetime +from uuid import UUID + +from sqlalchemy import CHAR, TIMESTAMP, ForeignKey, Uuid +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship + +from polar.config import settings +from polar.kit.db.models.base import RecordModel +from polar.kit.utils import utc_now +from polar.models.customer import Customer + + +def get_expires_at() -> datetime: + return utc_now() + settings.CUSTOMER_SESSION_CODE_TTL + + +class CustomerSessionCode(RecordModel): + __tablename__ = "customer_session_codes" + + code: Mapped[str] = mapped_column(CHAR(64), unique=True, nullable=False) + expires_at: Mapped[datetime] = mapped_column( + TIMESTAMP(timezone=True), nullable=False, index=True, default=get_expires_at + ) + + customer_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("customers.id", ondelete="cascade"), nullable=False + ) + + @declared_attr + def customer(cls) -> Mapped[Customer]: + return relationship(Customer, lazy="joined") diff --git a/server/polar/order/service.py b/server/polar/order/service.py index d2be3f896c..9e9a55c1d1 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -510,7 +510,9 @@ async def send_confirmation_email( { "featured_organization": organization, "product": product, - "url": f"{settings.FRONTEND_BASE_URL}/purchases/products/{order.id}", + "url": settings.generate_frontend_url( + f"/{organization.slug}/portal/orders/{order.id}" + ), "current_year": datetime.now().year, }, ) diff --git a/server/polar/subscription/service.py b/server/polar/subscription/service.py index 3ce6065991..e17d984109 100644 --- a/server/polar/subscription/service.py +++ b/server/polar/subscription/service.py @@ -692,9 +692,8 @@ async def send_confirmation_email( { "featured_organization": featured_organization, "product": product, - "url": ( - f"{settings.FRONTEND_BASE_URL}" - f"/purchases/subscriptions/{subscription.id}" + "url": settings.generate_frontend_url( + f"/{featured_organization.slug}/portal/subscriptions/{subscription.id}" ), "current_year": datetime.now().year, }, diff --git a/server/tests/magic_link/testdata/magic_link.html b/server/tests/magic_link/testdata/magic_link.html index 76a16527fb..79805dfbf3 100644 --- a/server/tests/magic_link/testdata/magic_link.html +++ b/server/tests/magic_link/testdata/magic_link.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/magic_link/testdata/magic_link_return_to.html b/server/tests/magic_link/testdata/magic_link_return_to.html index 7049d5f47c..e6997e8be8 100644 --- a/server/tests/magic_link/testdata/magic_link_return_to.html +++ b/server/tests/magic_link/testdata/magic_link_return_to.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerCreateAccountNotificationPayload.html b/server/tests/notifications/testdata/test_MaintainerCreateAccountNotificationPayload.html index 6f1b3b1c32..baffadb8a8 100644 --- a/server/tests/notifications/testdata/test_MaintainerCreateAccountNotificationPayload.html +++ b/server/tests/notifications/testdata/test_MaintainerCreateAccountNotificationPayload.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerNewPaidSubscriptionNotification.html b/server/tests/notifications/testdata/test_MaintainerNewPaidSubscriptionNotification.html index 76a2b3d0f5..56fabc5e5e 100644 --- a/server/tests/notifications/testdata/test_MaintainerNewPaidSubscriptionNotification.html +++ b/server/tests/notifications/testdata/test_MaintainerNewPaidSubscriptionNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerNewProductSaleNotification.html b/server/tests/notifications/testdata/test_MaintainerNewProductSaleNotification.html index 299511511b..ebc56d1bfb 100644 --- a/server/tests/notifications/testdata/test_MaintainerNewProductSaleNotification.html +++ b/server/tests/notifications/testdata/test_MaintainerNewProductSaleNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_no_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_no_stripe.html index e87a064520..5846a40f2a 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_no_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_no_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_with_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_with_stripe.html index be69daaf75..64f3ac3eda 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_with_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeConfirmationPendingdNotification_with_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_anonymous.html b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_anonymous.html index e7a8561b30..8b7172cabd 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_anonymous.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_anonymous.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_no_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_no_stripe.html index 4cb3c51ad1..9d5a3c5250 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_no_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_no_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_pay_on_completion.html b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_pay_on_completion.html index 9920a81a14..1a5f813cc8 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_pay_on_completion.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_pay_on_completion.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_with_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_with_stripe.html index 64318356f2..763a75810d 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_with_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgeCreatedNotification_with_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgePaidNotification.html b/server/tests/notifications/testdata/test_MaintainerPledgePaidNotification.html index c6d94ccd85..cfe5a30bd0 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgePaidNotification.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgePaidNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_no_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_no_stripe.html index 29c4669eb7..504827b001 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_no_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_no_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_with_stripe.html b/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_with_stripe.html index d8cd120157..540bc3818c 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_with_stripe.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgePendingdNotification_with_stripe.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification.html b/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification.html index c4c074b2f4..cfa2be1912 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification_with_account.html b/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification_with_account.html index 058f6c2c51..525955140b 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification_with_account.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgedIssueConfirmationPendingNotification_with_account.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification.html b/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification.html index 24b5bd1c7d..3da7a6b9ec 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification_with_account.html b/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification_with_account.html index 1722dbe0dd..464036dc1e 100644 --- a/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification_with_account.html +++ b/server/tests/notifications/testdata/test_MaintainerPledgedIssuePendingNotification_with_account.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_PledgerPledgePendingNotification.html b/server/tests/notifications/testdata/test_PledgerPledgePendingNotification.html index b7fcdb47ae..fb80a7ebcb 100644 --- a/server/tests/notifications/testdata/test_PledgerPledgePendingNotification.html +++ b/server/tests/notifications/testdata/test_PledgerPledgePendingNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_PledgerPledgePendingNotification_pay_on_completion.html b/server/tests/notifications/testdata/test_PledgerPledgePendingNotification_pay_on_completion.html index 1fefccc011..6b7a62a88c 100644 --- a/server/tests/notifications/testdata/test_PledgerPledgePendingNotification_pay_on_completion.html +++ b/server/tests/notifications/testdata/test_PledgerPledgePendingNotification_pay_on_completion.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_RewardPaidNotification.html b/server/tests/notifications/testdata/test_RewardPaidNotification.html index ef6fce21b8..1f91094823 100644 --- a/server/tests/notifications/testdata/test_RewardPaidNotification.html +++ b/server/tests/notifications/testdata/test_RewardPaidNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } diff --git a/server/tests/notifications/testdata/test_TeamAdminMemberPledgedNotification.html b/server/tests/notifications/testdata/test_TeamAdminMemberPledgedNotification.html index ba4780a6f4..cb39dd01e1 100644 --- a/server/tests/notifications/testdata/test_TeamAdminMemberPledgedNotification.html +++ b/server/tests/notifications/testdata/test_TeamAdminMemberPledgedNotification.html @@ -226,6 +226,26 @@ font-size: 15px; } + /* OTP Code ------------------------------ */ + + .otp { + width: 100%; + margin: 0; + padding: 24px; + -premailer-width: 100%; + -premailer-cellpadding: 0; + -premailer-cellspacing: 0; + background-color: #F4F4F7; + border: 0; + } + + .otp_heading { + text-align: center; + margin: 0; + font-size: 48px; + letter-spacing: 10px; + } + /* Social Icons ------------------------------ */ .social { @@ -463,7 +483,8 @@ } .attributes_content, - .discount { + .discount, + .otp { background-color: #26282b !important; } From fcfca9c9f39b898d3f7a736b831511980c7baf03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 13 Dec 2024 10:47:02 +0100 Subject: [PATCH 37/47] server/customer: fix schema names --- server/polar/customer/endpoints.py | 2 ++ server/polar/customer/schemas.py | 23 ++++++++++++++++++++--- server/polar/customer/service.py | 11 ++++++++++- server/polar/license_key/schemas.py | 14 +++----------- server/polar/order/schemas.py | 10 ++-------- server/polar/storefront/schemas.py | 8 ++++---- server/polar/subscription/endpoints.py | 5 +++++ server/polar/subscription/schemas.py | 11 ++--------- server/polar/subscription/service.py | 4 ++++ 9 files changed, 52 insertions(+), 36 deletions(-) diff --git a/server/polar/customer/endpoints.py b/server/polar/customer/endpoints.py index b10398bde0..bf97d20098 100644 --- a/server/polar/customer/endpoints.py +++ b/server/polar/customer/endpoints.py @@ -38,6 +38,7 @@ async def list( organization_id: MultipleQueryFilter[OrganizationID] | None = Query( None, title="OrganizationID Filter", description="Filter by organization ID." ), + query: str | None = Query(None, description="Filter by name or email."), session: AsyncSession = Depends(get_db_session), ) -> ListResource[CustomerSchema]: """List customers.""" @@ -45,6 +46,7 @@ async def list( session, auth_subject, organization_id=organization_id, + query=query, pagination=pagination, sorting=sorting, ) diff --git a/server/polar/customer/schemas.py b/server/polar/customer/schemas.py index 66080d8e2b..be2ab8d1ae 100644 --- a/server/polar/customer/schemas.py +++ b/server/polar/customer/schemas.py @@ -1,7 +1,15 @@ -from pydantic import UUID4, Field +import hashlib + +from pydantic import UUID4, Field, computed_field from polar.kit.address import Address -from polar.kit.schemas import EmailStrDNS, IDSchema, Schema, TimestampedSchema +from polar.kit.metadata import MetadataOutputMixin +from polar.kit.schemas import ( + EmailStrDNS, + IDSchema, + Schema, + TimestampedSchema, +) from polar.kit.tax import TaxID from polar.organization.schemas import OrganizationID @@ -27,10 +35,19 @@ class CustomerUpdate(Schema): tax_id: TaxID | None = None -class Customer(IDSchema, TimestampedSchema): +class CustomerBase(MetadataOutputMixin, IDSchema, TimestampedSchema): email: str email_verified: bool name: str | None billing_address: Address | None tax_id: TaxID | None organization_id: UUID4 + + @computed_field + def avatar_url(self) -> str: + email_hash = hashlib.sha256(self.email.lower().encode()).hexdigest() + return f"https://www.gravatar.com/avatar/{email_hash}?d=blank" + + +class Customer(CustomerBase): + """A customer in an organization.""" diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index f30bddc8b0..6ff65d668a 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import Any -from sqlalchemy import Select, UnaryExpression, asc, desc, func, select +from sqlalchemy import Select, UnaryExpression, asc, desc, func, or_, select from stripe import Customer as StripeCustomer from polar.auth.models import AuthSubject, is_organization, is_user @@ -26,6 +26,7 @@ async def list( auth_subject: AuthSubject[User | Organization], *, organization_id: Sequence[uuid.UUID] | None = None, + query: str | None = None, pagination: PaginationParams, sorting: list[Sorting[CustomerSortProperty]] = [ (CustomerSortProperty.created_at, True) @@ -36,6 +37,14 @@ async def list( if organization_id is not None: statement = statement.where(Customer.organization_id.in_(organization_id)) + if query is not None: + statement = statement.where( + or_( + Customer.email.ilike(f"%{query}%"), + Customer.name.ilike(f"%{query}%"), + ) + ) + order_by_clauses: list[UnaryExpression[Any]] = [] for criterion, is_desc in sorting: clause_function = desc if is_desc else asc diff --git a/server/polar/license_key/schemas.py b/server/polar/license_key/schemas.py index 34866fb88d..b68d8f9ce5 100644 --- a/server/polar/license_key/schemas.py +++ b/server/polar/license_key/schemas.py @@ -5,11 +5,9 @@ from pydantic import UUID4, AliasPath, Field from polar.benefit.schemas import BenefitID +from polar.customer.schemas import CustomerBase from polar.exceptions import ResourceNotFound, Unauthorized -from polar.kit.address import Address -from polar.kit.metadata import MetadataOutputMixin -from polar.kit.schemas import IDSchema, Schema, TimestampedSchema -from polar.kit.tax import TaxID +from polar.kit.schemas import Schema from polar.kit.utils import generate_uuid, utc_now from polar.models.benefit import ( BenefitLicenseKeyActivationProperties, @@ -61,13 +59,7 @@ class LicenseKeyDeactivate(Schema): activation_id: UUID4 -class LicenseKeyCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): - email: str - email_verified: bool - name: str | None - billing_address: Address | None - tax_id: TaxID | None - organization_id: UUID4 +class LicenseKeyCustomer(CustomerBase): ... class LicenseKeyUser(Schema): diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index c7d710c1ec..4196ed29d2 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -4,11 +4,11 @@ from pydantic import UUID4, AliasPath, Field from polar.custom_field.data import CustomFieldDataOutputMixin +from polar.customer.schemas import CustomerBase from polar.discount.schemas import DiscountMinimal from polar.kit.address import Address from polar.kit.metadata import MetadataOutputMixin from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema -from polar.kit.tax import TaxID from polar.models.order import OrderBillingReason from polar.product.schemas import ProductBase, ProductPrice from polar.subscription.schemas import SubscriptionBase @@ -38,13 +38,7 @@ def get_amount_display(self) -> str: )}" -class OrderCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): - email: str - email_verified: bool - name: str | None - billing_address: Address | None - tax_id: TaxID | None - organization_id: UUID4 +class OrderCustomer(CustomerBase): ... class OrderUser(Schema): diff --git a/server/polar/storefront/schemas.py b/server/polar/storefront/schemas.py index 3de30284ee..8e3984d1c6 100644 --- a/server/polar/storefront/schemas.py +++ b/server/polar/storefront/schemas.py @@ -21,13 +21,13 @@ class ProductStorefront(ProductBase): ) -class Customer(Schema): +class StorefrontCustomer(Schema): name: str -class Customers(Schema): +class StorefrontCustomers(Schema): total: int - customers: list[Customer] + customers: list[StorefrontCustomer] class Storefront(Schema): @@ -36,4 +36,4 @@ class Storefront(Schema): organization: Organization products: list[ProductStorefront] donation_product: ProductStorefront | None - customers: Customers + customers: StorefrontCustomers diff --git a/server/polar/subscription/endpoints.py b/server/polar/subscription/endpoints.py index a61048c364..95cecda7c3 100644 --- a/server/polar/subscription/endpoints.py +++ b/server/polar/subscription/endpoints.py @@ -4,6 +4,7 @@ import structlog from fastapi import Depends, Query, Response from fastapi.responses import StreamingResponse +from pydantic import UUID4 from polar.kit.csv import ( IterableCSVWriter, @@ -46,6 +47,9 @@ async def list( product_id: MultipleQueryFilter[ProductID] | None = Query( None, title="ProductID Filter", description="Filter by product ID." ), + customer_id: MultipleQueryFilter[UUID4] | None = Query( + None, title="CustomerID Filter", description="Filter by customer ID." + ), discount_id: MultipleQueryFilter[ProductID] | None = Query( None, title="DiscountID Filter", description="Filter by discount ID." ), @@ -60,6 +64,7 @@ async def list( auth_subject, organization_id=organization_id, product_id=product_id, + customer_id=customer_id, discount_id=discount_id, active=active, pagination=pagination, diff --git a/server/polar/subscription/schemas.py b/server/polar/subscription/schemas.py index a10c3954a5..9cd5d91508 100644 --- a/server/polar/subscription/schemas.py +++ b/server/polar/subscription/schemas.py @@ -5,9 +5,9 @@ from pydantic import UUID4, AliasPath, Field from polar.custom_field.data import CustomFieldDataOutputMixin +from polar.customer.schemas import CustomerBase from polar.discount.schemas import DiscountMinimal from polar.enums import SubscriptionRecurringInterval -from polar.kit.address import Address from polar.kit.metadata import MetadataOutputMixin from polar.kit.schemas import ( EmailStrDNS, @@ -16,18 +16,11 @@ Schema, TimestampedSchema, ) -from polar.kit.tax import TaxID from polar.models.subscription import SubscriptionStatus from polar.product.schemas import Product, ProductPriceRecurring -class SubscriptionCustomer(IDSchema, TimestampedSchema, MetadataOutputMixin): - email: str - email_verified: bool - name: str | None - billing_address: Address | None - tax_id: TaxID | None - organization_id: UUID4 +class SubscriptionCustomer(CustomerBase): ... class SubscriptionUser(Schema): diff --git a/server/polar/subscription/service.py b/server/polar/subscription/service.py index e17d984109..9516a436ee 100644 --- a/server/polar/subscription/service.py +++ b/server/polar/subscription/service.py @@ -187,6 +187,7 @@ async def list( *, organization_id: Sequence[uuid.UUID] | None = None, product_id: Sequence[uuid.UUID] | None = None, + customer_id: Sequence[uuid.UUID] | None = None, discount_id: Sequence[uuid.UUID] | None = None, active: bool | None = None, pagination: PaginationParams, @@ -210,6 +211,9 @@ async def list( if product_id is not None: statement = statement.where(Product.id.in_(product_id)) + if customer_id is not None: + statement = statement.where(Subscription.customer_id.in_(customer_id)) + if discount_id is not None: statement = statement.where(Subscription.discount_id.in_(discount_id)) From cf394431c0f8f60df0dc656078d0ccbe7aa7b11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 13 Dec 2024 14:59:59 +0100 Subject: [PATCH 38/47] server: generate a customer session when sending order/subscription confirmation email --- server/polar/order/service.py | 7 ++++++- server/polar/subscription/service.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index 9e9a55c1d1..952d5fe87c 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -14,6 +14,7 @@ from polar.checkout.service import checkout as checkout_service from polar.config import settings from polar.customer.service import customer as customer_service +from polar.customer_session.service import customer_session as customer_session_service from polar.discount.service import discount as discount_service from polar.email.renderer import get_email_renderer from polar.email.sender import get_email_sender @@ -504,6 +505,10 @@ async def send_confirmation_email( product = order.product customer = order.customer + token, _ = await customer_session_service.create_customer_session( + session, customer + ) + subject, body = email_renderer.render_from_template( "Your {{ product.name }} order confirmation", "order/confirmation.html", @@ -511,7 +516,7 @@ async def send_confirmation_email( "featured_organization": organization, "product": product, "url": settings.generate_frontend_url( - f"/{organization.slug}/portal/orders/{order.id}" + f"/{organization.slug}/portal/orders/{order.id}?customer_session_token={token}" ), "current_year": datetime.now().year, }, diff --git a/server/polar/subscription/service.py b/server/polar/subscription/service.py index 9516a436ee..c045d3c156 100644 --- a/server/polar/subscription/service.py +++ b/server/polar/subscription/service.py @@ -18,6 +18,7 @@ from polar.checkout.service import checkout as checkout_service from polar.config import settings from polar.customer.service import customer as customer_service +from polar.customer_session.service import customer_session as customer_session_service from polar.discount.service import discount as discount_service from polar.email.renderer import get_email_renderer from polar.email.sender import get_email_sender @@ -690,6 +691,11 @@ async def send_confirmation_email( ) assert featured_organization is not None + customer = subscription.customer + token, _ = await customer_session_service.create_customer_session( + session, customer + ) + subject, body = email_renderer.render_from_template( "Your {{ product.name }} subscription", "subscription/confirmation.html", @@ -697,7 +703,7 @@ async def send_confirmation_email( "featured_organization": featured_organization, "product": product, "url": settings.generate_frontend_url( - f"/{featured_organization.slug}/portal/subscriptions/{subscription.id}" + f"/{featured_organization.slug}/portal/subscriptions/{subscription.id}?customer_session_token={token}" ), "current_year": datetime.now().year, }, From feb333de8fce93f3f8a1b0a196d816a79ced086c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 16 Dec 2024 08:42:23 +0100 Subject: [PATCH 39/47] server: make Customer.stripe_customer_id non unique Ease migration from Users. Also fixes the logic that were trying to tie a payment Transaction to a Customer/User. It was simply not working because most of the time, the Customer/User didn't exist at the time of handling the transaction. Now, we link it when handling the Order, which already had some logic to update that Transaction. --- .../2024-12-10-1357_migrate_to_customers.py | 17 ++++--- server/polar/customer/service.py | 9 ++-- server/polar/models/customer.py | 2 +- server/polar/order/service.py | 1 + server/polar/transaction/service/payment.py | 25 +++------ server/tests/order/test_service.py | 20 +++++++- server/tests/subscription/test_service.py | 4 +- .../tests/transaction/service/test_payment.py | 51 +++---------------- 8 files changed, 52 insertions(+), 77 deletions(-) diff --git a/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py b/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py index 8d61c654d6..e9f0c9a55f 100644 --- a/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py +++ b/server/migrations/versions/2024-12-10-1357_migrate_to_customers.py @@ -58,9 +58,6 @@ def upgrade() -> None: ondelete="set null", ), sa.PrimaryKeyConstraint("id", name=op.f("customers_pkey")), - sa.UniqueConstraint( - "stripe_customer_id", name=op.f("customers_stripe_customer_id_key") - ), ) op.create_index( op.f("ix_customers_created_at"), "customers", ["created_at"], unique=False @@ -112,7 +109,7 @@ def upgrade() -> None: users.created_at, users.email, users.email_verified, - NULL, + users.stripe_customer_id, NULL, NULL, NULL, @@ -149,7 +146,7 @@ def upgrade() -> None: users.created_at, users.email, users.email_verified, - NULL, + users.stripe_customer_id, NULL, NULL, NULL, @@ -190,7 +187,7 @@ def upgrade() -> None: users.created_at, users.email, users.email_verified, - NULL, + users.stripe_customer_id, NULL, NULL, NULL, @@ -580,6 +577,14 @@ def upgrade() -> None: WHERE payment_customer_id IS NOT NULL """ ) + op.execute( + """ + UPDATE transactions + SET payment_customer_id = orders.customer_id + FROM orders + WHERE orders.id = transactions.order_id AND transactions.payment_customer_id IS NULL + """ + ) op.create_index( op.f("ix_transactions_payment_customer_id"), "transactions", diff --git a/server/polar/customer/service.py b/server/polar/customer/service.py index 6ff65d668a..47ba51c501 100644 --- a/server/polar/customer/service.py +++ b/server/polar/customer/service.py @@ -215,12 +215,13 @@ async def get_by_user( result = await session.execute(statement) return result.unique().scalars().all() - async def get_by_stripe_customer_id( - self, session: AsyncSession, stripe_customer_id: str + async def get_by_stripe_customer_id_and_organization( + self, session: AsyncSession, stripe_customer_id: str, organization: Organization ) -> Customer | None: statement = select(Customer).where( Customer.deleted_at.is_(None), Customer.stripe_customer_id == stripe_customer_id, + Customer.organization_id == organization.id, ) result = await session.execute(statement) return result.scalar_one_or_none() @@ -238,7 +239,9 @@ async def get_or_create_from_stripe_customer( If the customer does not exist, create a new one. """ - customer = await self.get_by_stripe_customer_id(session, stripe_customer.id) + customer = await self.get_by_stripe_customer_id_and_organization( + session, stripe_customer.id, organization + ) assert stripe_customer.email is not None if customer is None: customer = await self.get_by_email_and_organization( diff --git a/server/polar/models/customer.py b/server/polar/models/customer.py index 8499711a59..1902d8791d 100644 --- a/server/polar/models/customer.py +++ b/server/polar/models/customer.py @@ -77,7 +77,7 @@ class Customer(MetadataMixin, RecordModel): email: Mapped[str] = mapped_column(String(320), nullable=False) email_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) stripe_customer_id: Mapped[str | None] = mapped_column( - String, nullable=True, default=None, unique=True + String, nullable=True, default=None, unique=False ) name: Mapped[str | None] = mapped_column(String, nullable=True, default=None) diff --git a/server/polar/order/service.py b/server/polar/order/service.py index 952d5fe87c..e9c90734e1 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -564,6 +564,7 @@ async def _create_order_balance( transfer_amount = payment_transaction.amount payment_transaction.order = order + payment_transaction.payment_customer = order.customer session.add(payment_transaction) # Prepare an held balance diff --git a/server/polar/transaction/service/payment.py b/server/polar/transaction/service/payment.py index 20c670b2e6..451039cfa3 100644 --- a/server/polar/transaction/service/payment.py +++ b/server/polar/transaction/service/payment.py @@ -3,13 +3,11 @@ import stripe as stripe_lib from sqlalchemy import select -from polar.customer.service import customer as customer_service from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import stripe as stripe_service from polar.integrations.stripe.utils import get_expandable_id -from polar.models import Pledge, Transaction, User +from polar.models import Organization, Pledge, Transaction, User from polar.models.transaction import PaymentProcessor, TransactionType -from polar.organization.service import organization as organization_service from polar.pledge.service import pledge as pledge_service from polar.postgres import AsyncSession @@ -44,19 +42,6 @@ async def create_payment( if existing_transaction is not None: return existing_transaction - # Retrieve customer - customer_id = None - payment_customer = None - payment_organization = None - if charge.customer: - customer_id = get_expandable_id(charge.customer) - payment_customer = await customer_service.get_by_stripe_customer_id( - session, customer_id - ) - payment_organization = await organization_service.get_by( - session, stripe_customer_id=customer_id - ) - # Retrieve tax amount and country tax_amount = 0 tax_country = None @@ -87,6 +72,7 @@ async def create_payment( # Try to link with a Pledge payment_user: User | None = None + payment_organization: Organization | None = None if pledge_invoice or charge.metadata.get("type") == ProductType.pledge: assert charge.payment_intent is not None payment_intent = get_expandable_id(charge.payment_intent) @@ -96,7 +82,7 @@ async def create_payment( raise PledgeDoesNotExist(charge.id, payment_intent) # If we were not able to link to a payer by Stripe Customer ID, # link from the pledge data. Happens for anonymous pledges. - if payment_customer is None and payment_organization is None: + if payment_organization is None: await session.refresh(pledge, {"user", "by_organization"}) payment_user = pledge.user payment_organization = pledge.by_organization @@ -112,15 +98,16 @@ async def create_payment( tax_amount=tax_amount, tax_country=tax_country, tax_state=tax_state, - customer_id=customer_id, - payment_customer=payment_customer, + customer_id=get_expandable_id(charge.customer) if charge.customer else None, payment_organization=payment_organization, payment_user=payment_user, charge_id=charge.id, pledge=pledge, risk_level=risk.get("risk_level"), risk_score=risk.get("risk_score"), + # Filled when we handle the invoice order=None, + payment_customer=None, ) # Compute and link fees diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index 1d50359709..aee9ba8a0b 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -7,6 +7,7 @@ import pytest import stripe as stripe_lib from pytest_mock import MockerFixture +from sqlalchemy.orm import joinedload from polar.auth.models import AuthSubject from polar.held_balance.service import held_balance as held_balance_service @@ -356,10 +357,13 @@ async def test_subscription_no_account( assert held_balance.order_id == order.id updated_payment_transaction = await payment_transaction_service.get( - session, id=payment_transaction.id + session, + id=payment_transaction.id, + options=(joinedload(Transaction.payment_customer),), ) assert updated_payment_transaction is not None assert updated_payment_transaction.order_id == order.id + assert updated_payment_transaction.payment_customer == order.customer enqueue_job_mock.assert_called_once_with( "order.discord_notification", @@ -515,10 +519,13 @@ async def test_subscription_with_account( platform_fee_transaction_service_mock.create_fees_reversal_balances.assert_called_once() updated_payment_transaction = await payment_transaction_service.get( - session, id=payment_transaction.id + session, + id=payment_transaction.id, + options=(joinedload(Transaction.payment_customer),), ) assert updated_payment_transaction is not None assert updated_payment_transaction.order_id == order.id + assert updated_payment_transaction.payment_customer == order.customer enqueue_job_mock.assert_called_once_with( "order.discord_notification", @@ -669,6 +676,15 @@ async def test_one_time_product( assert order.billing_address == Address(country="FR") # pyright: ignore assert order.created_at == created_datetime + updated_payment_transaction = await payment_transaction_service.get( + session, + id=payment_transaction.id, + options=(joinedload(Transaction.payment_customer),), + ) + assert updated_payment_transaction is not None + assert updated_payment_transaction.order_id == order.id + assert updated_payment_transaction.payment_customer == order.customer + enqueue_job_mock.assert_any_call( "order.discord_notification", order_id=order.id, diff --git a/server/tests/subscription/test_service.py b/server/tests/subscription/test_service.py index 3158ac557c..0b7d38b2c1 100644 --- a/server/tests/subscription/test_service.py +++ b/server/tests/subscription/test_service.py @@ -271,8 +271,8 @@ async def test_new_customer( assert subscription.stripe_subscription_id == stripe_subscription.id assert subscription.product_id == product.id - customer = await customer_service.get_by_stripe_customer_id( - session, stripe_customer.id + customer = await customer_service.get_by_stripe_customer_id_and_organization( + session, stripe_customer.id, product.organization ) assert customer is not None assert customer.email == stripe_customer.email diff --git a/server/tests/transaction/service/test_payment.py b/server/tests/transaction/service/test_payment.py index cbe31100a4..4a60f06697 100644 --- a/server/tests/transaction/service/test_payment.py +++ b/server/tests/transaction/service/test_payment.py @@ -8,7 +8,7 @@ from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService -from polar.models import Customer, Organization, Pledge, Transaction, User +from polar.models import Customer, Pledge, Transaction from polar.models.transaction import PaymentProcessor, TransactionType from polar.postgres import AsyncSession from polar.transaction.service.payment import ( # type: ignore[attr-defined] @@ -101,12 +101,11 @@ async def test_existing_transaction( session: AsyncSession, save_fixture: SaveFixture, pledge: Pledge, - user: User, - stripe_service_mock: MagicMock, + customer: Customer, ) -> None: stripe_balance_transaction = build_stripe_balance_transaction() stripe_charge = build_stripe_charge( - customer=user.stripe_customer_id, + customer=customer.stripe_customer_id, payment_intent=pledge.payment_id, balance_transaction=stripe_balance_transaction.id, ) @@ -174,49 +173,12 @@ async def test_customer( ) assert transaction.type == TransactionType.payment - assert transaction.customer_id == customer.stripe_customer_id - assert transaction.payment_customer == customer + assert transaction.customer_id == stripe_charge.customer + assert transaction.payment_customer is None assert transaction.payment_organization is None assert transaction.risk_level == risk_level assert transaction.risk_score == risk_score - async def test_customer_organization( - self, - session: AsyncSession, - save_fixture: SaveFixture, - pledge: Pledge, - organization: Organization, - stripe_service_mock: MagicMock, - ) -> None: - organization.stripe_customer_id = "STRIPE_CUSTOMER_ID" - await save_fixture(organization) - pledge.by_organization = organization - pledge.payment_id = "STRIPE_PAYMENT_ID" - await save_fixture(pledge) - - stripe_balance_transaction = build_stripe_balance_transaction() - stripe_charge = build_stripe_charge( - customer=organization.stripe_customer_id, - payment_intent=pledge.payment_id, - balance_transaction=stripe_balance_transaction.id, - ) - - stripe_service_mock.get_balance_transaction.return_value = ( - stripe_balance_transaction - ) - - # then - session.expunge_all() - - transaction = await payment_transaction_service.create_payment( - session, charge=stripe_charge - ) - - assert transaction.type == TransactionType.payment - assert transaction.customer_id == organization.stripe_customer_id - assert transaction.payment_customer is None - assert transaction.payment_organization == organization - async def test_not_existing_pledge( self, session: AsyncSession, pledge: Pledge, stripe_service_mock: MagicMock ) -> None: @@ -318,7 +280,8 @@ async def test_anonymous_pledge( assert transaction.type == TransactionType.payment assert transaction.pledge == pledge - assert transaction.payment_customer == pledge.user + assert transaction.payment_customer is None + assert transaction.payment_user == pledge.user assert transaction.payment_organization == pledge.by_organization async def test_tax_metadata( From 8930d71b0a4929ff50af097c0c7d6891d8baca8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 11 Dec 2024 09:59:51 +0100 Subject: [PATCH 40/47] clients/sdk: update OpenAPI client --- .../sdk/src/client/.openapi-generator/FILES | 16 +- clients/packages/sdk/src/client/PolarAPI.ts | 48 +- .../sdk/src/client/apis/BenefitsApi.ts | 12 +- .../sdk/src/client/apis/CheckoutsCustomApi.ts | 13 +- .../apis/CustomerPortalBenefitGrantsApi.ts | 251 + .../apis/CustomerPortalCustomerSessionApi.ts | 113 + .../client/apis/CustomerPortalCustomersApi.ts | 83 + .../apis/CustomerPortalDownloadablesApi.ts | 135 + .../apis/CustomerPortalLicenseKeysApi.ts | 282 + .../apis/CustomerPortalOauthAccountsApi.ts | 165 + .../client/apis/CustomerPortalOrdersApi.ts | 232 + .../apis/CustomerPortalOrganizationsApi.ts | 67 + .../apis/CustomerPortalSubscriptionsApi.ts | 292 + .../sdk/src/client/apis/CustomersApi.ts | 298 + .../sdk/src/client/apis/LicenseKeysApi.ts | 4 +- .../packages/sdk/src/client/apis/OrdersApi.ts | 8 +- .../sdk/src/client/apis/ProductsApi.ts | 4 +- .../sdk/src/client/apis/SubscriptionsApi.ts | 6 + .../sdk/src/client/apis/TransactionsApi.ts | 11 +- clients/packages/sdk/src/client/apis/index.ts | 16 +- .../packages/sdk/src/client/models/index.ts | 15155 +++++++++------- 21 files changed, 10122 insertions(+), 7089 deletions(-) create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalBenefitGrantsApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalCustomerSessionApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalCustomersApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalDownloadablesApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalLicenseKeysApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalOauthAccountsApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalOrdersApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalOrganizationsApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomerPortalSubscriptionsApi.ts create mode 100644 clients/packages/sdk/src/client/apis/CustomersApi.ts diff --git a/clients/packages/sdk/src/client/.openapi-generator/FILES b/clients/packages/sdk/src/client/.openapi-generator/FILES index 92a94b0373..51b4a13f2e 100644 --- a/clients/packages/sdk/src/client/.openapi-generator/FILES +++ b/clients/packages/sdk/src/client/.openapi-generator/FILES @@ -7,6 +7,16 @@ apis/CheckoutLinksApi.ts apis/CheckoutsApi.ts apis/CheckoutsCustomApi.ts apis/CustomFieldsApi.ts +apis/CustomerPortalBenefitGrantsApi.ts +apis/CustomerPortalCustomerSessionApi.ts +apis/CustomerPortalCustomersApi.ts +apis/CustomerPortalDownloadablesApi.ts +apis/CustomerPortalLicenseKeysApi.ts +apis/CustomerPortalOauthAccountsApi.ts +apis/CustomerPortalOrdersApi.ts +apis/CustomerPortalOrganizationsApi.ts +apis/CustomerPortalSubscriptionsApi.ts +apis/CustomersApi.ts apis/DashboardApi.ts apis/DiscountsApi.ts apis/EmbedsApi.ts @@ -35,13 +45,7 @@ apis/RewardsApi.ts apis/StorefrontsApi.ts apis/SubscriptionsApi.ts apis/TransactionsApi.ts -apis/UsersAdvertisementsApi.ts apis/UsersApi.ts -apis/UsersBenefitsApi.ts -apis/UsersDownloadablesApi.ts -apis/UsersLicenseKeysApi.ts -apis/UsersOrdersApi.ts -apis/UsersSubscriptionsApi.ts apis/WebhooksApi.ts apis/index.ts models/index.ts diff --git a/clients/packages/sdk/src/client/PolarAPI.ts b/clients/packages/sdk/src/client/PolarAPI.ts index 1635ca894e..6513a7b94d 100644 --- a/clients/packages/sdk/src/client/PolarAPI.ts +++ b/clients/packages/sdk/src/client/PolarAPI.ts @@ -35,14 +35,18 @@ import { TransactionsApi, UsersApi, WebhooksApi, - UsersAdvertisementsApi, - UsersBenefitsApi, - UsersDownloadablesApi, - UsersLicenseKeysApi, - UsersOrdersApi, - UsersSubscriptionsApi, LicenseKeysApi, CheckoutLinksApi, + CustomerPortalBenefitGrantsApi, + CustomerPortalCustomerSessionApi, + CustomerPortalCustomersApi, + CustomerPortalDownloadablesApi, + CustomerPortalLicenseKeysApi, + CustomerPortalOauthAccountsApi, + CustomerPortalOrdersApi, + CustomerPortalOrganizationsApi, + CustomerPortalSubscriptionsApi, + CustomersApi, } from '.' export class PolarAPI { @@ -50,6 +54,16 @@ export class PolarAPI { public readonly advertisements: AdvertisementsApi public readonly auth: AuthApi public readonly backoffice: BackofficeApi + public readonly customers: CustomersApi + public readonly customerPortalBenefitGrants: CustomerPortalBenefitGrantsApi + public readonly customerPortalCustomers: CustomerPortalCustomersApi + public readonly customerPortalCustomerSession: CustomerPortalCustomerSessionApi + public readonly customerPortalDownloadables: CustomerPortalDownloadablesApi + public readonly customerPortalLicenseKeys: CustomerPortalLicenseKeysApi + public readonly customerPortalOauthAccounts: CustomerPortalOauthAccountsApi + public readonly customerPortalOrders: CustomerPortalOrdersApi + public readonly customerPortalOrganizations: CustomerPortalOrganizationsApi + public readonly customerPortalSubscriptions: CustomerPortalSubscriptionsApi public readonly legacyCheckouts: CheckoutsApi public readonly checkouts: CheckoutsCustomApi public readonly checkoutLinks: CheckoutLinksApi @@ -80,13 +94,7 @@ export class PolarAPI { public readonly storefronts: StorefrontsApi public readonly subscriptions: SubscriptionsApi public readonly transactions: TransactionsApi - public readonly usersAdvertisements: UsersAdvertisementsApi public readonly users: UsersApi - public readonly usersBenefits: UsersBenefitsApi - public readonly usersDownloadables: UsersDownloadablesApi - public readonly usersLicenseKeys: UsersLicenseKeysApi - public readonly usersOrders: UsersOrdersApi - public readonly usersSubscriptions: UsersSubscriptionsApi public readonly webhooks: WebhooksApi public readonly files: FilesApi @@ -95,6 +103,16 @@ export class PolarAPI { this.advertisements = new AdvertisementsApi(config) this.auth = new AuthApi(config) this.backoffice = new BackofficeApi(config) + this.customers = new CustomersApi(config) + this.customerPortalBenefitGrants = new CustomerPortalBenefitGrantsApi(config) + this.customerPortalCustomers = new CustomerPortalCustomersApi(config) + this.customerPortalCustomerSession = new CustomerPortalCustomerSessionApi(config) + this.customerPortalDownloadables = new CustomerPortalDownloadablesApi(config) + this.customerPortalLicenseKeys = new CustomerPortalLicenseKeysApi(config) + this.customerPortalOauthAccounts = new CustomerPortalOauthAccountsApi(config) + this.customerPortalOrders = new CustomerPortalOrdersApi(config) + this.customerPortalOrganizations = new CustomerPortalOrganizationsApi(config) + this.customerPortalSubscriptions = new CustomerPortalSubscriptionsApi(config) this.legacyCheckouts = new CheckoutsApi(config) this.checkouts = new CheckoutsCustomApi(config) this.checkoutLinks = new CheckoutLinksApi(config) @@ -126,13 +144,7 @@ export class PolarAPI { this.storefronts = new StorefrontsApi(config) this.subscriptions = new SubscriptionsApi(config) this.transactions = new TransactionsApi(config) - this.usersAdvertisements = new UsersAdvertisementsApi(config) this.users = new UsersApi(config) - this.usersBenefits = new UsersBenefitsApi(config) - this.usersDownloadables = new UsersDownloadablesApi(config) - this.usersLicenseKeys = new UsersLicenseKeysApi(config) - this.usersOrders = new UsersOrdersApi(config) - this.usersSubscriptions = new UsersSubscriptionsApi(config) this.webhooks = new WebhooksApi(config) this.files = new FilesApi(config) } diff --git a/clients/packages/sdk/src/client/apis/BenefitsApi.ts b/clients/packages/sdk/src/client/apis/BenefitsApi.ts index 200bf83496..1669031f8e 100644 --- a/clients/packages/sdk/src/client/apis/BenefitsApi.ts +++ b/clients/packages/sdk/src/client/apis/BenefitsApi.ts @@ -19,6 +19,7 @@ import type { BenefitCreate, BenefitTypeFilter, BenefitUpdate, + CustomerIDFilter1, HTTPValidationError, ListResourceBenefit, ListResourceBenefitGrant, @@ -42,8 +43,7 @@ export interface BenefitsApiGetRequest { export interface BenefitsApiGrantsRequest { id: string; isGranted?: boolean; - userId?: string; - githubUserId?: number; + customerId?: CustomerIDFilter1; page?: number; limit?: number; } @@ -214,12 +214,8 @@ export class BenefitsApi extends runtime.BaseAPI { queryParameters['is_granted'] = requestParameters['isGranted']; } - if (requestParameters['userId'] != null) { - queryParameters['user_id'] = requestParameters['userId']; - } - - if (requestParameters['githubUserId'] != null) { - queryParameters['github_user_id'] = requestParameters['githubUserId']; + if (requestParameters['customerId'] != null) { + queryParameters['customer_id'] = requestParameters['customerId']; } if (requestParameters['page'] != null) { diff --git a/clients/packages/sdk/src/client/apis/CheckoutsCustomApi.ts b/clients/packages/sdk/src/client/apis/CheckoutsCustomApi.ts index 0ba504c34b..d08f119b6b 100644 --- a/clients/packages/sdk/src/client/apis/CheckoutsCustomApi.ts +++ b/clients/packages/sdk/src/client/apis/CheckoutsCustomApi.ts @@ -20,6 +20,7 @@ import type { CheckoutCreate, CheckoutCreatePublic, CheckoutPublic, + CheckoutPublicConfirmed, CheckoutSortProperty, CheckoutUpdate, CheckoutUpdatePublic, @@ -78,7 +79,7 @@ export class CheckoutsCustomApi extends runtime.BaseAPI { * Confirm a checkout session by client secret. Orders and subscriptions will be processed. * Confirm Checkout Session from Client */ - async clientConfirmRaw(requestParameters: CheckoutsCustomApiClientConfirmRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + async clientConfirmRaw(requestParameters: CheckoutsCustomApiClientConfirmRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { if (requestParameters['clientSecret'] == null) { throw new runtime.RequiredError( 'clientSecret', @@ -99,6 +100,14 @@ export class CheckoutsCustomApi extends runtime.BaseAPI { headerParameters['Content-Type'] = 'application/json'; + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } const response = await this.request({ path: `/v1/checkouts/custom/client/{client_secret}/confirm`.replace(`{${"client_secret"}}`, encodeURIComponent(String(requestParameters['clientSecret']))), method: 'POST', @@ -114,7 +123,7 @@ export class CheckoutsCustomApi extends runtime.BaseAPI { * Confirm a checkout session by client secret. Orders and subscriptions will be processed. * Confirm Checkout Session from Client */ - async clientConfirm(requestParameters: CheckoutsCustomApiClientConfirmRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + async clientConfirm(requestParameters: CheckoutsCustomApiClientConfirmRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { const response = await this.clientConfirmRaw(requestParameters, initOverrides); return await response.value(); } diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalBenefitGrantsApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalBenefitGrantsApi.ts new file mode 100644 index 0000000000..21f2a86b3d --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalBenefitGrantsApi.ts @@ -0,0 +1,251 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + BenefitIDFilter2, + BenefitTypeFilter, + CheckoutIDFilter, + CustomerBenefitGrant, + CustomerBenefitGrantSortProperty, + CustomerBenefitGrantUpdate, + HTTPValidationError, + ListResourceCustomerBenefitGrant, + NotPermitted, + OrderIDFilter, + OrganizationIDFilter, + ResourceNotFound, + SubscriptionIDFilter, +} from '../models/index'; + +export interface CustomerPortalBenefitGrantsApiGetRequest { + id: string; +} + +export interface CustomerPortalBenefitGrantsApiListRequest { + type?: BenefitTypeFilter; + benefitId?: BenefitIDFilter2; + organizationId?: OrganizationIDFilter; + checkoutId?: CheckoutIDFilter; + orderId?: OrderIDFilter; + subscriptionId?: SubscriptionIDFilter; + page?: number; + limit?: number; + sorting?: Array; +} + +export interface CustomerPortalBenefitGrantsApiUpdateRequest { + id: string; + body: CustomerBenefitGrantUpdate; +} + +/** + * + */ +export class CustomerPortalBenefitGrantsApi extends runtime.BaseAPI { + + /** + * Get a benefit grant by ID for the authenticated customer or user. + * Get Benefit Grant + */ + async getRaw(requestParameters: CustomerPortalBenefitGrantsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/benefit-grants/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a benefit grant by ID for the authenticated customer or user. + * Get Benefit Grant + */ + async get(requestParameters: CustomerPortalBenefitGrantsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List benefits grants of the authenticated customer or user. + * List Benefit Grants + */ + async listRaw(requestParameters: CustomerPortalBenefitGrantsApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['type'] != null) { + queryParameters['type'] = requestParameters['type']; + } + + if (requestParameters['benefitId'] != null) { + queryParameters['benefit_id'] = requestParameters['benefitId']; + } + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['checkoutId'] != null) { + queryParameters['checkout_id'] = requestParameters['checkoutId']; + } + + if (requestParameters['orderId'] != null) { + queryParameters['order_id'] = requestParameters['orderId']; + } + + if (requestParameters['subscriptionId'] != null) { + queryParameters['subscription_id'] = requestParameters['subscriptionId']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + if (requestParameters['sorting'] != null) { + queryParameters['sorting'] = requestParameters['sorting']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/benefit-grants/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List benefits grants of the authenticated customer or user. + * List Benefit Grants + */ + async list(requestParameters: CustomerPortalBenefitGrantsApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Update a benefit grant for the authenticated customer or user. + * Update Benefit Grant + */ + async updateRaw(requestParameters: CustomerPortalBenefitGrantsApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling update().' + ); + } + + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling update().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/benefit-grants/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'PATCH', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Update a benefit grant for the authenticated customer or user. + * Update Benefit Grant + */ + async update(requestParameters: CustomerPortalBenefitGrantsApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.updateRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalCustomerSessionApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalCustomerSessionApi.ts new file mode 100644 index 0000000000..3ee559b3a9 --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalCustomerSessionApi.ts @@ -0,0 +1,113 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + CustomerSessionCodeAuthenticateRequest, + CustomerSessionCodeAuthenticateResponse, + CustomerSessionCodeRequest, + HTTPValidationError, +} from '../models/index'; + +export interface CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionAuthenticateRequest { + body: CustomerSessionCodeAuthenticateRequest; +} + +export interface CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionRequestRequest { + body: CustomerSessionCodeRequest; +} + +/** + * + */ +export class CustomerPortalCustomerSessionApi extends runtime.BaseAPI { + + /** + * Customer Portal.Customer Session.Authenticate + */ + async customerPortalCustomerSessionAuthenticateRaw(requestParameters: CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionAuthenticateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling customerPortalCustomerSessionAuthenticate().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + const response = await this.request({ + path: `/v1/customer-portal/customer-session/authenticate`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Customer Portal.Customer Session.Authenticate + */ + async customerPortalCustomerSessionAuthenticate(requestParameters: CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionAuthenticateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.customerPortalCustomerSessionAuthenticateRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Customer Portal.Customer Session.Request + */ + async customerPortalCustomerSessionRequestRaw(requestParameters: CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionRequestRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling customerPortalCustomerSessionRequest().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + const response = await this.request({ + path: `/v1/customer-portal/customer-session/request`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + if (this.isJsonMime(response.headers.get('content-type'))) { + return new runtime.JSONApiResponse(response); + } else { + return new runtime.TextApiResponse(response) as any; + } + } + + /** + * Customer Portal.Customer Session.Request + */ + async customerPortalCustomerSessionRequest(requestParameters: CustomerPortalCustomerSessionApiCustomerPortalCustomerSessionRequestRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.customerPortalCustomerSessionRequestRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalCustomersApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalCustomersApi.ts new file mode 100644 index 0000000000..056ee00787 --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalCustomersApi.ts @@ -0,0 +1,83 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + CustomerPortalCustomer, + HTTPValidationError, + ResourceNotFound, +} from '../models/index'; + +export interface CustomerPortalCustomersApiGetRequest { + id: string; +} + +/** + * + */ +export class CustomerPortalCustomersApi extends runtime.BaseAPI { + + /** + * Get a customer by ID for the authenticated customer or user. + * Get Customer + */ + async getRaw(requestParameters: CustomerPortalCustomersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/customers/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a customer by ID for the authenticated customer or user. + * Get Customer + */ + async get(requestParameters: CustomerPortalCustomersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalDownloadablesApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalDownloadablesApi.ts new file mode 100644 index 0000000000..70fd104043 --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalDownloadablesApi.ts @@ -0,0 +1,135 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + BenefitIDFilter1, + HTTPValidationError, + ListResourceDownloadableRead, + OrganizationIDFilter, +} from '../models/index'; + +export interface CustomerPortalDownloadablesApiCustomerPortalDownloadablesGetRequest { + token: string; +} + +export interface CustomerPortalDownloadablesApiListRequest { + organizationId?: OrganizationIDFilter; + benefitId?: BenefitIDFilter1; + page?: number; + limit?: number; +} + +/** + * + */ +export class CustomerPortalDownloadablesApi extends runtime.BaseAPI { + + /** + * Get Downloadable + */ + async customerPortalDownloadablesGetRaw(requestParameters: CustomerPortalDownloadablesApiCustomerPortalDownloadablesGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['token'] == null) { + throw new runtime.RequiredError( + 'token', + 'Required parameter "token" was null or undefined when calling customerPortalDownloadablesGet().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + const response = await this.request({ + path: `/v1/customer-portal/downloadables/{token}`.replace(`{${"token"}}`, encodeURIComponent(String(requestParameters['token']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + if (this.isJsonMime(response.headers.get('content-type'))) { + return new runtime.JSONApiResponse(response); + } else { + return new runtime.TextApiResponse(response) as any; + } + } + + /** + * Get Downloadable + */ + async customerPortalDownloadablesGet(requestParameters: CustomerPortalDownloadablesApiCustomerPortalDownloadablesGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.customerPortalDownloadablesGetRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List Downloadables + */ + async listRaw(requestParameters: CustomerPortalDownloadablesApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['benefitId'] != null) { + queryParameters['benefit_id'] = requestParameters['benefitId']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/downloadables/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List Downloadables + */ + async list(requestParameters: CustomerPortalDownloadablesApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalLicenseKeysApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalLicenseKeysApi.ts new file mode 100644 index 0000000000..065212db89 --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalLicenseKeysApi.ts @@ -0,0 +1,282 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + HTTPValidationError, + LicenseKeyActivate, + LicenseKeyActivationRead, + LicenseKeyDeactivate, + LicenseKeyValidate, + LicenseKeyWithActivations, + ListResourceLicenseKeyRead, + NotPermitted, + OrganizationIDFilter, + ResourceNotFound, + Unauthorized, + ValidatedLicenseKey, +} from '../models/index'; + +export interface CustomerPortalLicenseKeysApiActivateRequest { + body: LicenseKeyActivate; +} + +export interface CustomerPortalLicenseKeysApiDeactivateRequest { + body: LicenseKeyDeactivate; +} + +export interface CustomerPortalLicenseKeysApiGetRequest { + id: string; +} + +export interface CustomerPortalLicenseKeysApiListRequest { + organizationId?: OrganizationIDFilter; + benefitId?: string; + page?: number; + limit?: number; +} + +export interface CustomerPortalLicenseKeysApiValidateRequest { + body: LicenseKeyValidate; +} + +/** + * + */ +export class CustomerPortalLicenseKeysApi extends runtime.BaseAPI { + + /** + * Activate a license key instance. + * Activate License Key + */ + async activateRaw(requestParameters: CustomerPortalLicenseKeysApiActivateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling activate().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + const response = await this.request({ + path: `/v1/customer-portal/license-keys/activate`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Activate a license key instance. + * Activate License Key + */ + async activate(requestParameters: CustomerPortalLicenseKeysApiActivateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.activateRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Deactivate a license key instance. + * Deactivate License Key + */ + async deactivateRaw(requestParameters: CustomerPortalLicenseKeysApiDeactivateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling deactivate().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + const response = await this.request({ + path: `/v1/customer-portal/license-keys/deactivate`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.VoidApiResponse(response); + } + + /** + * Deactivate a license key instance. + * Deactivate License Key + */ + async deactivate(requestParameters: CustomerPortalLicenseKeysApiDeactivateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + await this.deactivateRaw(requestParameters, initOverrides); + } + + /** + * Get a license key. + * Get License Key + */ + async getRaw(requestParameters: CustomerPortalLicenseKeysApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/license-keys/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a license key. + * Get License Key + */ + async get(requestParameters: CustomerPortalLicenseKeysApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List License Keys + */ + async listRaw(requestParameters: CustomerPortalLicenseKeysApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['benefitId'] != null) { + queryParameters['benefit_id'] = requestParameters['benefitId']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/license-keys/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List License Keys + */ + async list(requestParameters: CustomerPortalLicenseKeysApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Validate a license key. + * Validate License Key + */ + async validateRaw(requestParameters: CustomerPortalLicenseKeysApiValidateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling validate().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + const response = await this.request({ + path: `/v1/customer-portal/license-keys/validate`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Validate a license key. + * Validate License Key + */ + async validate(requestParameters: CustomerPortalLicenseKeysApiValidateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.validateRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalOauthAccountsApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalOauthAccountsApi.ts new file mode 100644 index 0000000000..fc32004388 --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalOauthAccountsApi.ts @@ -0,0 +1,165 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + AuthorizeResponse, + CustomerOAuthPlatform, + HTTPValidationError, +} from '../models/index'; + +export interface CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsAuthorizeRequest { + platform: CustomerOAuthPlatform; + customerId: string; + returnTo?: string; +} + +export interface CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsCallbackRequest { + state: string; + code?: string; + error?: string; +} + +/** + * + */ +export class CustomerPortalOauthAccountsApi extends runtime.BaseAPI { + + /** + * Customer Portal.Oauth Accounts.Authorize + */ + async customerPortalOauthAccountsAuthorizeRaw(requestParameters: CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsAuthorizeRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['platform'] == null) { + throw new runtime.RequiredError( + 'platform', + 'Required parameter "platform" was null or undefined when calling customerPortalOauthAccountsAuthorize().' + ); + } + + if (requestParameters['customerId'] == null) { + throw new runtime.RequiredError( + 'customerId', + 'Required parameter "customerId" was null or undefined when calling customerPortalOauthAccountsAuthorize().' + ); + } + + const queryParameters: any = {}; + + if (requestParameters['platform'] != null) { + queryParameters['platform'] = requestParameters['platform']; + } + + if (requestParameters['customerId'] != null) { + queryParameters['customer_id'] = requestParameters['customerId']; + } + + if (requestParameters['returnTo'] != null) { + queryParameters['return_to'] = requestParameters['returnTo']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/oauth-accounts/authorize`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Customer Portal.Oauth Accounts.Authorize + */ + async customerPortalOauthAccountsAuthorize(requestParameters: CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsAuthorizeRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.customerPortalOauthAccountsAuthorizeRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Customer Portal.Oauth Accounts.Callback + */ + async customerPortalOauthAccountsCallbackRaw(requestParameters: CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsCallbackRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['state'] == null) { + throw new runtime.RequiredError( + 'state', + 'Required parameter "state" was null or undefined when calling customerPortalOauthAccountsCallback().' + ); + } + + const queryParameters: any = {}; + + if (requestParameters['state'] != null) { + queryParameters['state'] = requestParameters['state']; + } + + if (requestParameters['code'] != null) { + queryParameters['code'] = requestParameters['code']; + } + + if (requestParameters['error'] != null) { + queryParameters['error'] = requestParameters['error']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/oauth-accounts/callback`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + if (this.isJsonMime(response.headers.get('content-type'))) { + return new runtime.JSONApiResponse(response); + } else { + return new runtime.TextApiResponse(response) as any; + } + } + + /** + * Customer Portal.Oauth Accounts.Callback + */ + async customerPortalOauthAccountsCallback(requestParameters: CustomerPortalOauthAccountsApiCustomerPortalOauthAccountsCallbackRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.customerPortalOauthAccountsCallbackRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalOrdersApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalOrdersApi.ts new file mode 100644 index 0000000000..fdc65488fa --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalOrdersApi.ts @@ -0,0 +1,232 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + CustomerOrder, + CustomerOrderInvoice, + CustomerOrderSortProperty, + HTTPValidationError, + ListResourceCustomerOrder, + OrganizationIDFilter, + ProductIDFilter, + ProductPriceTypeFilter, + ResourceNotFound, + SubscriptionIDFilter, +} from '../models/index'; + +export interface CustomerPortalOrdersApiGetRequest { + id: string; +} + +export interface CustomerPortalOrdersApiInvoiceRequest { + id: string; +} + +export interface CustomerPortalOrdersApiListRequest { + organizationId?: OrganizationIDFilter; + productId?: ProductIDFilter; + productPriceType?: ProductPriceTypeFilter; + subscriptionId?: SubscriptionIDFilter; + query?: string; + page?: number; + limit?: number; + sorting?: Array; +} + +/** + * + */ +export class CustomerPortalOrdersApi extends runtime.BaseAPI { + + /** + * Get an order by ID for the authenticated customer or user. + * Get Order + */ + async getRaw(requestParameters: CustomerPortalOrdersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/orders/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get an order by ID for the authenticated customer or user. + * Get Order + */ + async get(requestParameters: CustomerPortalOrdersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Get an order\'s invoice data. + * Get Order Invoice + */ + async invoiceRaw(requestParameters: CustomerPortalOrdersApiInvoiceRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling invoice().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/orders/{id}/invoice`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get an order\'s invoice data. + * Get Order Invoice + */ + async invoice(requestParameters: CustomerPortalOrdersApiInvoiceRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.invoiceRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List orders of the authenticated customer or user. + * List Orders + */ + async listRaw(requestParameters: CustomerPortalOrdersApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['productId'] != null) { + queryParameters['product_id'] = requestParameters['productId']; + } + + if (requestParameters['productPriceType'] != null) { + queryParameters['product_price_type'] = requestParameters['productPriceType']; + } + + if (requestParameters['subscriptionId'] != null) { + queryParameters['subscription_id'] = requestParameters['subscriptionId']; + } + + if (requestParameters['query'] != null) { + queryParameters['query'] = requestParameters['query']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + if (requestParameters['sorting'] != null) { + queryParameters['sorting'] = requestParameters['sorting']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/orders/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List orders of the authenticated customer or user. + * List Orders + */ + async list(requestParameters: CustomerPortalOrdersApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalOrganizationsApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalOrganizationsApi.ts new file mode 100644 index 0000000000..d4a837c3db --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalOrganizationsApi.ts @@ -0,0 +1,67 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + HTTPValidationError, + Organization, + ResourceNotFound, +} from '../models/index'; + +export interface CustomerPortalOrganizationsApiGetRequest { + slug: string; +} + +/** + * + */ +export class CustomerPortalOrganizationsApi extends runtime.BaseAPI { + + /** + * Get a customer portal\'s organization by slug. + * Get Organization + */ + async getRaw(requestParameters: CustomerPortalOrganizationsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['slug'] == null) { + throw new runtime.RequiredError( + 'slug', + 'Required parameter "slug" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + const response = await this.request({ + path: `/v1/customer-portal/organizations/{slug}`.replace(`{${"slug"}}`, encodeURIComponent(String(requestParameters['slug']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a customer portal\'s organization by slug. + * Get Organization + */ + async get(requestParameters: CustomerPortalOrganizationsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomerPortalSubscriptionsApi.ts b/clients/packages/sdk/src/client/apis/CustomerPortalSubscriptionsApi.ts new file mode 100644 index 0000000000..e0fbd94f8d --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomerPortalSubscriptionsApi.ts @@ -0,0 +1,292 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + AlreadyCanceledSubscription, + CustomerSubscription, + CustomerSubscriptionSortProperty, + CustomerSubscriptionUpdate, + HTTPValidationError, + ListResourceCustomerSubscription, + OrganizationIDFilter, + ProductIDFilter, + ResourceNotFound, +} from '../models/index'; + +export interface CustomerPortalSubscriptionsApiCancelRequest { + id: string; +} + +export interface CustomerPortalSubscriptionsApiGetRequest { + id: string; +} + +export interface CustomerPortalSubscriptionsApiListRequest { + organizationId?: OrganizationIDFilter; + productId?: ProductIDFilter; + active?: boolean; + query?: string; + page?: number; + limit?: number; + sorting?: Array; +} + +export interface CustomerPortalSubscriptionsApiUpdateRequest { + id: string; + body: CustomerSubscriptionUpdate; +} + +/** + * + */ +export class CustomerPortalSubscriptionsApi extends runtime.BaseAPI { + + /** + * Cancel a subscription of the authenticated customer or user. + * Cancel Subscription + */ + async cancelRaw(requestParameters: CustomerPortalSubscriptionsApiCancelRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling cancel().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/subscriptions/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'DELETE', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Cancel a subscription of the authenticated customer or user. + * Cancel Subscription + */ + async cancel(requestParameters: CustomerPortalSubscriptionsApiCancelRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.cancelRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Get a subscription for the authenticated customer or user. + * Get Subscription + */ + async getRaw(requestParameters: CustomerPortalSubscriptionsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/subscriptions/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a subscription for the authenticated customer or user. + * Get Subscription + */ + async get(requestParameters: CustomerPortalSubscriptionsApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List subscriptions of the authenticated customer or user. + * List Subscriptions + */ + async listRaw(requestParameters: CustomerPortalSubscriptionsApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['productId'] != null) { + queryParameters['product_id'] = requestParameters['productId']; + } + + if (requestParameters['active'] != null) { + queryParameters['active'] = requestParameters['active']; + } + + if (requestParameters['query'] != null) { + queryParameters['query'] = requestParameters['query']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + if (requestParameters['sorting'] != null) { + queryParameters['sorting'] = requestParameters['sorting']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/subscriptions/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List subscriptions of the authenticated customer or user. + * List Subscriptions + */ + async list(requestParameters: CustomerPortalSubscriptionsApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Update a subscription of the authenticated customer or user. + * Update Subscription + */ + async updateRaw(requestParameters: CustomerPortalSubscriptionsApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling update().' + ); + } + + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling update().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("customer_session", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customer-portal/subscriptions/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'PATCH', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Update a subscription of the authenticated customer or user. + * Update Subscription + */ + async update(requestParameters: CustomerPortalSubscriptionsApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.updateRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/CustomersApi.ts b/clients/packages/sdk/src/client/apis/CustomersApi.ts new file mode 100644 index 0000000000..c8e7a41f3e --- /dev/null +++ b/clients/packages/sdk/src/client/apis/CustomersApi.ts @@ -0,0 +1,298 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * Polar API + * Read the docs at https://docs.polar.sh/api + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + + +import * as runtime from '../runtime'; +import type { + Customer, + CustomerCreate, + CustomerSortProperty, + CustomerUpdate, + HTTPValidationError, + ListResourceCustomer, + OrganizationIDFilter, + ResourceNotFound, +} from '../models/index'; + +export interface CustomersApiCreateRequest { + body: CustomerCreate; +} + +export interface CustomersApiDeleteRequest { + id: string; +} + +export interface CustomersApiGetRequest { + id: string; +} + +export interface CustomersApiListRequest { + organizationId?: OrganizationIDFilter; + query?: string; + page?: number; + limit?: number; + sorting?: Array; +} + +export interface CustomersApiUpdateRequest { + id: string; + body: CustomerUpdate; +} + +/** + * + */ +export class CustomersApi extends runtime.BaseAPI { + + /** + * Create a customer. + * Create Customer + */ + async createRaw(requestParameters: CustomersApiCreateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling create().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customers/`, + method: 'POST', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Create a customer. + * Create Customer + */ + async create(requestParameters: CustomersApiCreateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.createRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Delete a customer. Immediately cancels any active subscriptions and revokes any active benefits. + * Delete Customer + */ + async deleteRaw(requestParameters: CustomersApiDeleteRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling delete().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customers/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'DELETE', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.VoidApiResponse(response); + } + + /** + * Delete a customer. Immediately cancels any active subscriptions and revokes any active benefits. + * Delete Customer + */ + async delete(requestParameters: CustomersApiDeleteRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + await this.deleteRaw(requestParameters, initOverrides); + } + + /** + * Get a customer by ID. + * Get Customer + */ + async getRaw(requestParameters: CustomersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling get().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customers/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Get a customer by ID. + * Get Customer + */ + async get(requestParameters: CustomersApiGetRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.getRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * List customers. + * List Customers + */ + async listRaw(requestParameters: CustomersApiListRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + const queryParameters: any = {}; + + if (requestParameters['organizationId'] != null) { + queryParameters['organization_id'] = requestParameters['organizationId']; + } + + if (requestParameters['query'] != null) { + queryParameters['query'] = requestParameters['query']; + } + + if (requestParameters['page'] != null) { + queryParameters['page'] = requestParameters['page']; + } + + if (requestParameters['limit'] != null) { + queryParameters['limit'] = requestParameters['limit']; + } + + if (requestParameters['sorting'] != null) { + queryParameters['sorting'] = requestParameters['sorting']; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customers/`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * List customers. + * List Customers + */ + async list(requestParameters: CustomersApiListRequest = {}, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.listRaw(requestParameters, initOverrides); + return await response.value(); + } + + /** + * Update a customer. + * Update Customer + */ + async updateRaw(requestParameters: CustomersApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { + if (requestParameters['id'] == null) { + throw new runtime.RequiredError( + 'id', + 'Required parameter "id" was null or undefined when calling update().' + ); + } + + if (requestParameters['body'] == null) { + throw new runtime.RequiredError( + 'body', + 'Required parameter "body" was null or undefined when calling update().' + ); + } + + const queryParameters: any = {}; + + const headerParameters: runtime.HTTPHeaders = {}; + + headerParameters['Content-Type'] = 'application/json'; + + if (this.configuration && this.configuration.accessToken) { + const token = this.configuration.accessToken; + const tokenString = await token("pat", []); + + if (tokenString) { + headerParameters["Authorization"] = `Bearer ${tokenString}`; + } + } + const response = await this.request({ + path: `/v1/customers/{id}`.replace(`{${"id"}}`, encodeURIComponent(String(requestParameters['id']))), + method: 'PATCH', + headers: headerParameters, + query: queryParameters, + body: requestParameters['body'], + }, initOverrides); + + return new runtime.JSONApiResponse(response); + } + + /** + * Update a customer. + * Update Customer + */ + async update(requestParameters: CustomersApiUpdateRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { + const response = await this.updateRaw(requestParameters, initOverrides); + return await response.value(); + } + +} diff --git a/clients/packages/sdk/src/client/apis/LicenseKeysApi.ts b/clients/packages/sdk/src/client/apis/LicenseKeysApi.ts index 0cee609368..937a7590ab 100644 --- a/clients/packages/sdk/src/client/apis/LicenseKeysApi.ts +++ b/clients/packages/sdk/src/client/apis/LicenseKeysApi.ts @@ -15,7 +15,7 @@ import * as runtime from '../runtime'; import type { - BenefitIDFilter2, + BenefitIDFilter1, HTTPValidationError, LicenseKeyActivationRead, LicenseKeyRead, @@ -38,7 +38,7 @@ export interface LicenseKeysApiGetActivationRequest { export interface LicenseKeysApiListRequest { organizationId?: OrganizationIDFilter; - benefitId?: BenefitIDFilter2; + benefitId?: BenefitIDFilter1; page?: number; limit?: number; } diff --git a/clients/packages/sdk/src/client/apis/OrdersApi.ts b/clients/packages/sdk/src/client/apis/OrdersApi.ts index 9684024d72..782f26e35a 100644 --- a/clients/packages/sdk/src/client/apis/OrdersApi.ts +++ b/clients/packages/sdk/src/client/apis/OrdersApi.ts @@ -15,6 +15,7 @@ import * as runtime from '../runtime'; import type { + CustomerIDFilter, DiscountIDFilter1, HTTPValidationError, ListResourceOrder, @@ -25,7 +26,6 @@ import type { ProductIDFilter, ProductPriceTypeFilter, ResourceNotFound, - UserIDFilter, } from '../models/index'; export interface OrdersApiGetRequest { @@ -41,7 +41,7 @@ export interface OrdersApiListRequest { productId?: ProductIDFilter; productPriceType?: ProductPriceTypeFilter; discountId?: DiscountIDFilter1; - userId?: UserIDFilter; + customerId?: CustomerIDFilter; page?: number; limit?: number; sorting?: Array; @@ -161,8 +161,8 @@ export class OrdersApi extends runtime.BaseAPI { queryParameters['discount_id'] = requestParameters['discountId']; } - if (requestParameters['userId'] != null) { - queryParameters['user_id'] = requestParameters['userId']; + if (requestParameters['customerId'] != null) { + queryParameters['customer_id'] = requestParameters['customerId']; } if (requestParameters['page'] != null) { diff --git a/clients/packages/sdk/src/client/apis/ProductsApi.ts b/clients/packages/sdk/src/client/apis/ProductsApi.ts index 191549fed4..4b57660373 100644 --- a/clients/packages/sdk/src/client/apis/ProductsApi.ts +++ b/clients/packages/sdk/src/client/apis/ProductsApi.ts @@ -15,7 +15,7 @@ import * as runtime from '../runtime'; import type { - BenefitIDFilter1, + BenefitIDFilter, HTTPValidationError, ListResourceProduct, NotPermitted, @@ -41,7 +41,7 @@ export interface ProductsApiListRequest { query?: string; isArchived?: boolean; isRecurring?: boolean; - benefitId?: BenefitIDFilter1; + benefitId?: BenefitIDFilter; page?: number; limit?: number; sorting?: Array; diff --git a/clients/packages/sdk/src/client/apis/SubscriptionsApi.ts b/clients/packages/sdk/src/client/apis/SubscriptionsApi.ts index e69df34d4d..df4b77053b 100644 --- a/clients/packages/sdk/src/client/apis/SubscriptionsApi.ts +++ b/clients/packages/sdk/src/client/apis/SubscriptionsApi.ts @@ -15,6 +15,7 @@ import * as runtime from '../runtime'; import type { + CustomerIDFilter, DiscountIDFilter, HTTPValidationError, ListResourceSubscription, @@ -31,6 +32,7 @@ export interface SubscriptionsApiExportRequest { export interface SubscriptionsApiListRequest { organizationId?: OrganizationIDFilter; productId?: ProductIDFilter; + customerId?: CustomerIDFilter; discountId?: DiscountIDFilter; active?: boolean; page?: number; @@ -102,6 +104,10 @@ export class SubscriptionsApi extends runtime.BaseAPI { queryParameters['product_id'] = requestParameters['productId']; } + if (requestParameters['customerId'] != null) { + queryParameters['customer_id'] = requestParameters['customerId']; + } + if (requestParameters['discountId'] != null) { queryParameters['discount_id'] = requestParameters['discountId']; } diff --git a/clients/packages/sdk/src/client/apis/TransactionsApi.ts b/clients/packages/sdk/src/client/apis/TransactionsApi.ts index 256fca6186..79ef8c5c87 100644 --- a/clients/packages/sdk/src/client/apis/TransactionsApi.ts +++ b/clients/packages/sdk/src/client/apis/TransactionsApi.ts @@ -49,8 +49,9 @@ export interface TransactionsApiLookupTransactionRequest { export interface TransactionsApiSearchTransactionsRequest { type?: TransactionType; accountId?: string; - paymentUserId?: string; + paymentCustomerId?: string; paymentOrganizationId?: string; + paymentUserId?: string; excludePlatformFees?: boolean; page?: number; limit?: number; @@ -300,14 +301,18 @@ export class TransactionsApi extends runtime.BaseAPI { queryParameters['account_id'] = requestParameters['accountId']; } - if (requestParameters['paymentUserId'] != null) { - queryParameters['payment_user_id'] = requestParameters['paymentUserId']; + if (requestParameters['paymentCustomerId'] != null) { + queryParameters['payment_customer_id'] = requestParameters['paymentCustomerId']; } if (requestParameters['paymentOrganizationId'] != null) { queryParameters['payment_organization_id'] = requestParameters['paymentOrganizationId']; } + if (requestParameters['paymentUserId'] != null) { + queryParameters['payment_user_id'] = requestParameters['paymentUserId']; + } + if (requestParameters['excludePlatformFees'] != null) { queryParameters['exclude_platform_fees'] = requestParameters['excludePlatformFees']; } diff --git a/clients/packages/sdk/src/client/apis/index.ts b/clients/packages/sdk/src/client/apis/index.ts index 0356288947..7370326bc0 100644 --- a/clients/packages/sdk/src/client/apis/index.ts +++ b/clients/packages/sdk/src/client/apis/index.ts @@ -9,6 +9,16 @@ export * from './CheckoutLinksApi'; export * from './CheckoutsApi'; export * from './CheckoutsCustomApi'; export * from './CustomFieldsApi'; +export * from './CustomerPortalBenefitGrantsApi'; +export * from './CustomerPortalCustomerSessionApi'; +export * from './CustomerPortalCustomersApi'; +export * from './CustomerPortalDownloadablesApi'; +export * from './CustomerPortalLicenseKeysApi'; +export * from './CustomerPortalOauthAccountsApi'; +export * from './CustomerPortalOrdersApi'; +export * from './CustomerPortalOrganizationsApi'; +export * from './CustomerPortalSubscriptionsApi'; +export * from './CustomersApi'; export * from './DashboardApi'; export * from './DiscountsApi'; export * from './EmbedsApi'; @@ -38,10 +48,4 @@ export * from './StorefrontsApi'; export * from './SubscriptionsApi'; export * from './TransactionsApi'; export * from './UsersApi'; -export * from './UsersAdvertisementsApi'; -export * from './UsersBenefitsApi'; -export * from './UsersDownloadablesApi'; -export * from './UsersLicenseKeysApi'; -export * from './UsersOrdersApi'; -export * from './UsersSubscriptionsApi'; export * from './WebhooksApi'; diff --git a/clients/packages/sdk/src/client/models/index.ts b/clients/packages/sdk/src/client/models/index.ts index e991ceb335..d30d2b9564 100644 --- a/clients/packages/sdk/src/client/models/index.ts +++ b/clients/packages/sdk/src/client/models/index.ts @@ -1159,6 +1159,19 @@ export interface AuthorizeOrganization { */ avatar_url: string | null; } +/** + * + * @export + * @interface AuthorizeResponse + */ +export interface AuthorizeResponse { + /** + * + * @type {string} + * @memberof AuthorizeResponse + */ + url: string; +} /** * * @export @@ -1300,6 +1313,8 @@ export const AvailableScope = { FILESWRITE: 'files:write', SUBSCRIPTIONSREAD: 'subscriptions:read', SUBSCRIPTIONSWRITE: 'subscriptions:write', + CUSTOMERSREAD: 'customers:read', + CUSTOMERSWRITE: 'customers:write', ORDERSREAD: 'orders:read', METRICSREAD: 'metrics:read', WEBHOOKSREAD: 'webhooks:read', @@ -1311,14 +1326,8 @@ export const AvailableScope = { REPOSITORIESWRITE: 'repositories:write', ISSUESREAD: 'issues:read', ISSUESWRITE: 'issues:write', - USERBENEFITSREAD: 'user:benefits:read', - USERORDERSREAD: 'user:orders:read', - USERSUBSCRIPTIONSREAD: 'user:subscriptions:read', - USERSUBSCRIPTIONSWRITE: 'user:subscriptions:write', - USERDOWNLOADABLESREAD: 'user:downloadables:read', - USERLICENSE_KEYSREAD: 'user:license_keys:read', - USERADVERTISEMENT_CAMPAIGNSREAD: 'user:advertisement_campaigns:read', - USERADVERTISEMENT_CAMPAIGNSWRITE: 'user:advertisement_campaigns:write' + CUSTOMER_PORTALREAD: 'customer_portal:read', + CUSTOMER_PORTALWRITE: 'customer_portal:write' } as const; export type AvailableScope = typeof AvailableScope[keyof typeof AvailableScope]; @@ -1803,12 +1812,6 @@ export interface BenefitAdsSubscriber { * @memberof BenefitAdsSubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitAdsSubscriber - */ - grants: Array; /** * * @type {Organization} @@ -2071,7 +2074,7 @@ export interface BenefitCustomCreateProperties { } /** * @type BenefitCustomCreatePropertiesNote - * Private note to be shared with users who have this benefit granted. + * Private note to be shared with customers who have this benefit granted. * @export */ export type BenefitCustomCreatePropertiesNote = string; @@ -2143,12 +2146,6 @@ export interface BenefitCustomSubscriber { * @memberof BenefitCustomSubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitCustomSubscriber - */ - grants: Array; /** * * @type {Organization} @@ -2432,12 +2429,6 @@ export interface BenefitDiscordSubscriber { * @memberof BenefitDiscordSubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitDiscordSubscriber - */ - grants: Array; /** * * @type {Organization} @@ -2713,12 +2704,6 @@ export interface BenefitDownloadablesSubscriber { * @memberof BenefitDownloadablesSubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitDownloadablesSubscriber - */ - grants: Array; /** * * @type {Organization} @@ -2911,23 +2896,17 @@ export type BenefitGitHubRepositoryCreateTypeEnum = typeof BenefitGitHubReposito */ export interface BenefitGitHubRepositoryCreateProperties { /** - * - * @type {string} - * @memberof BenefitGitHubRepositoryCreateProperties - */ - repository_id?: string | null; - /** - * + * The owner of the repository. * @type {string} * @memberof BenefitGitHubRepositoryCreateProperties */ - repository_owner?: string | null; + repository_owner: string; /** - * + * The name of the repository. * @type {string} * @memberof BenefitGitHubRepositoryCreateProperties */ - repository_name?: string | null; + repository_name: string; /** * The permission level to grant. Read more about roles and their permissions on [GitHub documentation](https://docs.github.com/en/organizations/managing-user-access-to-your-organizations-repositories/managing-repository-roles/repository-roles-for-an-organization#permissions-for-each-role). * @type {string} @@ -2955,12 +2934,6 @@ export type BenefitGitHubRepositoryCreatePropertiesPermissionEnum = typeof Benef * @interface BenefitGitHubRepositoryProperties */ export interface BenefitGitHubRepositoryProperties { - /** - * - * @type {string} - * @memberof BenefitGitHubRepositoryProperties - */ - repository_id: string | null; /** * The owner of the repository. * @type {string} @@ -3048,12 +3021,6 @@ export interface BenefitGitHubRepositorySubscriber { * @memberof BenefitGitHubRepositorySubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitGitHubRepositorySubscriber - */ - grants: Array; /** * * @type {Organization} @@ -3192,102 +3159,30 @@ export interface BenefitGrant { */ order_id: string | null; /** - * The ID of the user concerned by this grant. - * @type {string} - * @memberof BenefitGrant - */ - user_id: string; - /** - * The ID of the benefit concerned by this grant. + * The ID of the customer concerned by this grant. * @type {string} * @memberof BenefitGrant */ - benefit_id: string; - /** - * - * @type {Properties} - * @memberof BenefitGrant - */ - properties: Properties; -} -/** - * - * @export - * @interface BenefitGrantAds - */ -export interface BenefitGrantAds { - /** - * Creation timestamp of the object. - * @type {string} - * @memberof BenefitGrantAds - */ - created_at: string; - /** - * - * @type {string} - * @memberof BenefitGrantAds - */ - modified_at: string | null; - /** - * The ID of the grant. - * @type {string} - * @memberof BenefitGrantAds - */ - id: string; - /** - * - * @type {string} - * @memberof BenefitGrantAds - */ - granted_at?: string | null; - /** - * Whether the benefit is granted. - * @type {boolean} - * @memberof BenefitGrantAds - */ - is_granted: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantAds - */ - revoked_at?: string | null; - /** - * Whether the benefit is revoked. - * @type {boolean} - * @memberof BenefitGrantAds - */ - is_revoked: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantAds - */ - subscription_id: string | null; + customer_id: string; /** * * @type {string} - * @memberof BenefitGrantAds - */ - order_id: string | null; - /** - * The ID of the user concerned by this grant. - * @type {string} - * @memberof BenefitGrantAds + * @memberof BenefitGrant + * @deprecated */ user_id: string; /** * The ID of the benefit concerned by this grant. * @type {string} - * @memberof BenefitGrantAds + * @memberof BenefitGrant */ benefit_id: string; /** * - * @type {BenefitGrantAdsSubscriberProperties} - * @memberof BenefitGrantAds + * @type {Properties} + * @memberof BenefitGrant */ - properties: BenefitGrantAdsSubscriberProperties; + properties: Properties; } /** * @@ -3302,19 +3197,6 @@ export interface BenefitGrantAdsProperties { */ advertisement_campaign_id: string; } -/** - * - * @export - * @interface BenefitGrantAdsSubscriberProperties - */ -export interface BenefitGrantAdsSubscriberProperties { - /** - * - * @type {string} - * @memberof BenefitGrantAdsSubscriberProperties - */ - advertisement_campaign_id?: string | null; -} /** * * @export @@ -3326,19 +3208,19 @@ export interface BenefitGrantDiscordProperties { * @type {string} * @memberof BenefitGrantDiscordProperties */ - guild_id?: string; + account_id?: string; /** * * @type {string} * @memberof BenefitGrantDiscordProperties */ - role_id?: string; + guild_id?: string; /** * * @type {string} * @memberof BenefitGrantDiscordProperties */ - account_id?: string; + role_id?: string; } /** * @@ -3364,7 +3246,7 @@ export interface BenefitGrantGitHubRepositoryProperties { * @type {string} * @memberof BenefitGrantGitHubRepositoryProperties */ - repository_id?: string | null; + account_id?: string; /** * * @type {string} @@ -3401,364 +3283,219 @@ export type BenefitGrantGitHubRepositoryPropertiesPermissionEnum = typeof Benefi /** * * @export - * @interface BenefitGrantLicenseKeys + * @interface BenefitGrantLicenseKeysProperties + */ +export interface BenefitGrantLicenseKeysProperties { + /** + * + * @type {string} + * @memberof BenefitGrantLicenseKeysProperties + */ + license_key_id?: string; + /** + * + * @type {string} + * @memberof BenefitGrantLicenseKeysProperties + */ + display_key?: string; +} +/** + * + * @export + * @interface BenefitGrantWebhook */ -export interface BenefitGrantLicenseKeys { +export interface BenefitGrantWebhook { /** * Creation timestamp of the object. * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ created_at: string; /** * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ modified_at: string | null; /** * The ID of the grant. * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ id: string; /** * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ granted_at?: string | null; /** * Whether the benefit is granted. * @type {boolean} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ is_granted: boolean; /** * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ revoked_at?: string | null; /** * Whether the benefit is revoked. * @type {boolean} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ is_revoked: boolean; /** * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ subscription_id: string | null; /** * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ order_id: string | null; /** - * The ID of the user concerned by this grant. + * The ID of the customer concerned by this grant. + * @type {string} + * @memberof BenefitGrantWebhook + */ + customer_id: string; + /** + * * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook + * @deprecated */ user_id: string; /** * The ID of the benefit concerned by this grant. * @type {string} - * @memberof BenefitGrantLicenseKeys + * @memberof BenefitGrantWebhook */ benefit_id: string; /** * - * @type {BenefitGrantLicenseKeysProperties} - * @memberof BenefitGrantLicenseKeys + * @type {Properties} + * @memberof BenefitGrantWebhook */ - properties: BenefitGrantLicenseKeysProperties; + properties: Properties; + /** + * + * @type {Benefit} + * @memberof BenefitGrantWebhook + */ + benefit: Benefit; + /** + * + * @type {PreviousProperties} + * @memberof BenefitGrantWebhook + */ + previous_properties?: PreviousProperties | null; +} +/** + * @type BenefitIDFilter + * Filter products granting specific benefit. + * @export + */ +export type BenefitIDFilter = Array | string; + +/** + * @type BenefitIDFilter1 + * Filter by benefit ID. + * @export + */ +export type BenefitIDFilter1 = Array | string; + +/** + * @type BenefitIDFilter2 + * Filter by benefit ID. + * @export + */ +export type BenefitIDFilter2 = Array | string; + +/** + * + * @export + * @interface BenefitLicenseKeyActivationProperties + */ +export interface BenefitLicenseKeyActivationProperties { + /** + * + * @type {number} + * @memberof BenefitLicenseKeyActivationProperties + */ + limit: number; + /** + * + * @type {boolean} + * @memberof BenefitLicenseKeyActivationProperties + */ + enable_customer_admin: boolean; } /** * * @export - * @interface BenefitGrantLicenseKeysProperties + * @interface BenefitLicenseKeyExpirationProperties */ -export interface BenefitGrantLicenseKeysProperties { +export interface BenefitLicenseKeyExpirationProperties { /** * - * @type {string} - * @memberof BenefitGrantLicenseKeysProperties + * @type {number} + * @memberof BenefitLicenseKeyExpirationProperties */ - license_key_id?: string; + ttl: number; /** * * @type {string} - * @memberof BenefitGrantLicenseKeysProperties + * @memberof BenefitLicenseKeyExpirationProperties */ - display_key?: string; + timeframe: BenefitLicenseKeyExpirationPropertiesTimeframeEnum; } + + +/** + * @export + */ +export const BenefitLicenseKeyExpirationPropertiesTimeframeEnum = { + YEAR: 'year', + MONTH: 'month', + DAY: 'day' +} as const; +export type BenefitLicenseKeyExpirationPropertiesTimeframeEnum = typeof BenefitLicenseKeyExpirationPropertiesTimeframeEnum[keyof typeof BenefitLicenseKeyExpirationPropertiesTimeframeEnum]; + /** * * @export - * @interface BenefitGrantSubscriber + * @interface BenefitLicenseKeys */ -export interface BenefitGrantSubscriber { +export interface BenefitLicenseKeys { /** * Creation timestamp of the object. * @type {string} - * @memberof BenefitGrantSubscriber + * @memberof BenefitLicenseKeys */ created_at: string; /** * * @type {string} - * @memberof BenefitGrantSubscriber + * @memberof BenefitLicenseKeys */ modified_at: string | null; /** - * The ID of the grant. + * The ID of the benefit. * @type {string} - * @memberof BenefitGrantSubscriber + * @memberof BenefitLicenseKeys */ id: string; /** * * @type {string} - * @memberof BenefitGrantSubscriber - */ - granted_at?: string | null; - /** - * Whether the benefit is granted. - * @type {boolean} - * @memberof BenefitGrantSubscriber - */ - is_granted: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantSubscriber - */ - revoked_at?: string | null; - /** - * Whether the benefit is revoked. - * @type {boolean} - * @memberof BenefitGrantSubscriber - */ - is_revoked: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantSubscriber - */ - subscription_id: string | null; - /** - * - * @type {string} - * @memberof BenefitGrantSubscriber - */ - order_id: string | null; - /** - * The ID of the user concerned by this grant. - * @type {string} - * @memberof BenefitGrantSubscriber - */ - user_id: string; - /** - * The ID of the benefit concerned by this grant. - * @type {string} - * @memberof BenefitGrantSubscriber - */ - benefit_id: string; -} -/** - * - * @export - * @interface BenefitGrantWebhook - */ -export interface BenefitGrantWebhook { - /** - * Creation timestamp of the object. - * @type {string} - * @memberof BenefitGrantWebhook - */ - created_at: string; - /** - * - * @type {string} - * @memberof BenefitGrantWebhook - */ - modified_at: string | null; - /** - * The ID of the grant. - * @type {string} - * @memberof BenefitGrantWebhook - */ - id: string; - /** - * - * @type {string} - * @memberof BenefitGrantWebhook - */ - granted_at?: string | null; - /** - * Whether the benefit is granted. - * @type {boolean} - * @memberof BenefitGrantWebhook - */ - is_granted: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantWebhook - */ - revoked_at?: string | null; - /** - * Whether the benefit is revoked. - * @type {boolean} - * @memberof BenefitGrantWebhook - */ - is_revoked: boolean; - /** - * - * @type {string} - * @memberof BenefitGrantWebhook - */ - subscription_id: string | null; - /** - * - * @type {string} - * @memberof BenefitGrantWebhook - */ - order_id: string | null; - /** - * The ID of the user concerned by this grant. - * @type {string} - * @memberof BenefitGrantWebhook - */ - user_id: string; - /** - * The ID of the benefit concerned by this grant. - * @type {string} - * @memberof BenefitGrantWebhook - */ - benefit_id: string; - /** - * - * @type {Properties} - * @memberof BenefitGrantWebhook - */ - properties: Properties; - /** - * - * @type {Benefit} - * @memberof BenefitGrantWebhook - */ - benefit: Benefit; - /** - * - * @type {PreviousProperties} - * @memberof BenefitGrantWebhook - */ - previous_properties?: PreviousProperties | null; -} -/** - * @type BenefitIDFilter - * Filter by given benefit ID. - * @export - */ -export type BenefitIDFilter = Array | string; - -/** - * @type BenefitIDFilter1 - * Filter products granting specific benefit. - * @export - */ -export type BenefitIDFilter1 = Array | string; - -/** - * @type BenefitIDFilter2 - * Filter by benefit ID. - * @export - */ -export type BenefitIDFilter2 = Array | string; - -/** - * - * @export - * @interface BenefitLicenseKeyActivationProperties - */ -export interface BenefitLicenseKeyActivationProperties { - /** - * - * @type {number} - * @memberof BenefitLicenseKeyActivationProperties - */ - limit: number; - /** - * - * @type {boolean} - * @memberof BenefitLicenseKeyActivationProperties - */ - enable_user_admin: boolean; -} -/** - * - * @export - * @interface BenefitLicenseKeyExpirationProperties - */ -export interface BenefitLicenseKeyExpirationProperties { - /** - * - * @type {number} - * @memberof BenefitLicenseKeyExpirationProperties - */ - ttl: number; - /** - * - * @type {string} - * @memberof BenefitLicenseKeyExpirationProperties - */ - timeframe: BenefitLicenseKeyExpirationPropertiesTimeframeEnum; -} - - -/** - * @export - */ -export const BenefitLicenseKeyExpirationPropertiesTimeframeEnum = { - YEAR: 'year', - MONTH: 'month', - DAY: 'day' -} as const; -export type BenefitLicenseKeyExpirationPropertiesTimeframeEnum = typeof BenefitLicenseKeyExpirationPropertiesTimeframeEnum[keyof typeof BenefitLicenseKeyExpirationPropertiesTimeframeEnum]; - -/** - * - * @export - * @interface BenefitLicenseKeys - */ -export interface BenefitLicenseKeys { - /** - * Creation timestamp of the object. - * @type {string} - * @memberof BenefitLicenseKeys - */ - created_at: string; - /** - * - * @type {string} - * @memberof BenefitLicenseKeys - */ - modified_at: string | null; - /** - * The ID of the benefit. - * @type {string} - * @memberof BenefitLicenseKeys - */ - id: string; - /** - * - * @type {string} - * @memberof BenefitLicenseKeys + * @memberof BenefitLicenseKeys */ type: BenefitLicenseKeysTypeEnum; /** @@ -3959,12 +3696,6 @@ export interface BenefitLicenseKeysSubscriber { * @memberof BenefitLicenseKeysSubscriber */ organization_id: string; - /** - * - * @type {Array} - * @memberof BenefitLicenseKeysSubscriber - */ - grants: Array; /** * * @type {Organization} @@ -4054,96 +3785,6 @@ export const BenefitLicenseKeysUpdateTypeEnum = { } as const; export type BenefitLicenseKeysUpdateTypeEnum = typeof BenefitLicenseKeysUpdateTypeEnum[keyof typeof BenefitLicenseKeysUpdateTypeEnum]; -/** - * - * @export - * @interface BenefitPreconditionErrorNotification - */ -export interface BenefitPreconditionErrorNotification { - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotification - */ - id: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotification - */ - created_at: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotification - */ - type: BenefitPreconditionErrorNotificationTypeEnum; - /** - * - * @type {BenefitPreconditionErrorNotificationPayload} - * @memberof BenefitPreconditionErrorNotification - */ - payload: BenefitPreconditionErrorNotificationPayload; -} - - -/** - * @export - */ -export const BenefitPreconditionErrorNotificationTypeEnum = { - BENEFIT_PRECONDITION_ERROR_NOTIFICATION: 'BenefitPreconditionErrorNotification' -} as const; -export type BenefitPreconditionErrorNotificationTypeEnum = typeof BenefitPreconditionErrorNotificationTypeEnum[keyof typeof BenefitPreconditionErrorNotificationTypeEnum]; - -/** - * - * @export - * @interface BenefitPreconditionErrorNotificationPayload - */ -export interface BenefitPreconditionErrorNotificationPayload { - /** - * - * @type {object} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - extra_context?: object; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - subject_template: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - body_template: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - scope_name: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - benefit_id: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - benefit_description: string; - /** - * - * @type {string} - * @memberof BenefitPreconditionErrorNotificationPayload - */ - organization_name: string; -} /** * @@ -4719,6 +4360,13 @@ export interface CheckoutDiscountPercentageRepeatDuration { } +/** + * @type CheckoutIDFilter + * Filter by checkout ID. + * @export + */ +export type CheckoutIDFilter = Array | string; + /** * A checkout session. * @export @@ -5227,6 +4875,12 @@ export interface CheckoutPriceCreate { * @memberof CheckoutPriceCreate */ amount?: number | null; + /** + * + * @type {string} + * @memberof CheckoutPriceCreate + */ + customer_id?: string | null; /** * Name of the customer. * @type {string} @@ -5419,6 +5073,12 @@ export interface CheckoutProductCreate { * @memberof CheckoutProductCreate */ amount?: number | null; + /** + * + * @type {string} + * @memberof CheckoutProductCreate + */ + customer_id?: string | null; /** * Name of the customer. * @type {string} @@ -5715,6 +5375,254 @@ export interface CheckoutPublic { } +/** + * Checkout session data retrieved using the client secret after confirmation. + * + * It contains a customer session token to retrieve order information + * right after the checkout. + * @export + * @interface CheckoutPublicConfirmed + */ +export interface CheckoutPublicConfirmed { + /** + * Creation timestamp of the object. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + created_at: string; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + id: string; + /** + * Key-value object storing custom field values. + * @type {object} + * @memberof CheckoutPublicConfirmed + */ + custom_field_data?: object; + /** + * + * @type {PolarEnumsPaymentProcessor} + * @memberof CheckoutPublicConfirmed + */ + payment_processor: PolarEnumsPaymentProcessor; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + status: CheckoutPublicConfirmedStatusEnum; + /** + * Client secret used to update and complete the checkout session from the client. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + client_secret: string; + /** + * URL where the customer can access the checkout session. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + url: string; + /** + * Expiration date and time of the checkout session. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + expires_at: string; + /** + * URL where the customer will be redirected after a successful payment. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + success_url: string; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + embed_origin: string | null; + /** + * Amount to pay in cents. Only useful for custom prices, it'll be ignored for fixed and free prices. + * @type {number} + * @memberof CheckoutPublicConfirmed + */ + amount: number | null; + /** + * + * @type {number} + * @memberof CheckoutPublicConfirmed + */ + tax_amount: number | null; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + currency: string | null; + /** + * + * @type {number} + * @memberof CheckoutPublicConfirmed + */ + subtotal_amount: number | null; + /** + * + * @type {number} + * @memberof CheckoutPublicConfirmed + */ + total_amount: number | null; + /** + * ID of the product to checkout. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + product_id: string; + /** + * ID of the product price to checkout. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + product_price_id: string; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + discount_id: string | null; + /** + * Whether to allow the customer to apply discount codes. If you apply a discount through `discount_id`, it'll still be applied, but the customer won't be able to change it. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + allow_discount_codes: boolean; + /** + * Whether the discount is applicable to the checkout. Typically, free and custom prices are not discountable. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + is_discount_applicable: boolean; + /** + * Whether the product price is free, regardless of discounts. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + is_free_product_price: boolean; + /** + * Whether the checkout requires payment, e.g. in case of free products or discounts that cover the total amount. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + is_payment_required: boolean; + /** + * Whether the checkout requires setting up a payment method, regardless of the amount, e.g. subscriptions that have first free cycles. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + is_payment_setup_required: boolean; + /** + * Whether the checkout requires a payment form, whether because of a payment or payment method setup. + * @type {boolean} + * @memberof CheckoutPublicConfirmed + */ + is_payment_form_required: boolean; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_id: string | null; + /** + * Name of the customer. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_name: string | null; + /** + * Email address of the customer. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_email: string | null; + /** + * IP address of the customer. Used to detect tax location. + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_ip_address: string | null; + /** + * + * @type {Address} + * @memberof CheckoutPublicConfirmed + */ + customer_billing_address: Address | null; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_tax_id: string | null; + /** + * + * @type {object} + * @memberof CheckoutPublicConfirmed + */ + payment_processor_metadata: object; + /** + * + * @type {CheckoutProduct} + * @memberof CheckoutPublicConfirmed + */ + product: CheckoutProduct; + /** + * + * @type {ProductPrice} + * @memberof CheckoutPublicConfirmed + */ + product_price: ProductPrice; + /** + * + * @type {CheckoutDiscount} + * @memberof CheckoutPublicConfirmed + */ + discount: CheckoutDiscount | null; + /** + * + * @type {Organization} + * @memberof CheckoutPublicConfirmed + */ + organization: Organization; + /** + * + * @type {Array} + * @memberof CheckoutPublicConfirmed + */ + attached_custom_fields: Array; + /** + * + * @type {string} + * @memberof CheckoutPublicConfirmed + */ + customer_session_token: string; +} + + +/** + * @export + */ +export const CheckoutPublicConfirmedStatusEnum = { + CONFIRMED: 'confirmed' +} as const; +export type CheckoutPublicConfirmedStatusEnum = typeof CheckoutPublicConfirmedStatusEnum[keyof typeof CheckoutPublicConfirmedStatusEnum]; + /** * @@ -7138,11619 +7046,13011 @@ export const CustomFieldUpdateTextTypeEnum = { export type CustomFieldUpdateTextTypeEnum = typeof CustomFieldUpdateTextTypeEnum[keyof typeof CustomFieldUpdateTextTypeEnum]; /** - * + * A customer in an organization. * @export * @interface Customer */ export interface Customer { /** - * + * Creation timestamp of the object. * @type {string} * @memberof Customer */ - public_name: string; + created_at: string; /** * * @type {string} * @memberof Customer */ - github_username: string | null; + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof Customer + */ + id: string; + /** + * + * @type {{ [key: string]: MetadataValue; }} + * @memberof Customer + */ + metadata: { [key: string]: MetadataValue; }; /** * * @type {string} * @memberof Customer */ - avatar_url: string | null; -} -/** - * - * @export - * @interface Customers - */ -export interface Customers { + email: string; /** * - * @type {number} - * @memberof Customers + * @type {boolean} + * @memberof Customer */ - total: number; + email_verified: boolean; /** * - * @type {Array} - * @memberof Customers + * @type {string} + * @memberof Customer */ - customers: Array; -} -/** - * - * @export - * @interface DiscordGuild - */ -export interface DiscordGuild { + name: string | null; + /** + * + * @type {Address} + * @memberof Customer + */ + billing_address: Address | null; + /** + * + * @type {Array} + * @memberof Customer + */ + tax_id: Array | null; /** * * @type {string} - * @memberof DiscordGuild + * @memberof Customer */ - name: string; + organization_id: string; /** * - * @type {Array} - * @memberof DiscordGuild + * @type {string} + * @memberof Customer */ - roles: Array; + readonly avatar_url: string; } +/** + * @type CustomerBenefitGrant + * @export + */ +export type CustomerBenefitGrant = CustomerBenefitGrantAds | CustomerBenefitGrantCustom | CustomerBenefitGrantDiscord | CustomerBenefitGrantDownloadables | CustomerBenefitGrantGitHubRepository | CustomerBenefitGrantLicenseKeys; + /** * * @export - * @interface DiscordGuildRole + * @interface CustomerBenefitGrantAds */ -export interface DiscordGuildRole { +export interface CustomerBenefitGrantAds { + /** + * Creation timestamp of the object. + * @type {string} + * @memberof CustomerBenefitGrantAds + */ + created_at: string; /** * * @type {string} - * @memberof DiscordGuildRole + * @memberof CustomerBenefitGrantAds + */ + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof CustomerBenefitGrantAds */ id: string; /** * * @type {string} - * @memberof DiscordGuildRole + * @memberof CustomerBenefitGrantAds */ - name: string; + granted_at: string | null; /** * - * @type {number} - * @memberof DiscordGuildRole + * @type {string} + * @memberof CustomerBenefitGrantAds */ - position: number; + revoked_at: string | null; /** * - * @type {boolean} - * @memberof DiscordGuildRole + * @type {string} + * @memberof CustomerBenefitGrantAds */ - is_polar_bot: boolean; + customer_id: string; /** * * @type {string} - * @memberof DiscordGuildRole + * @memberof CustomerBenefitGrantAds */ - color: string; + benefit_id: string; + /** + * + * @type {string} + * @memberof CustomerBenefitGrantAds + */ + subscription_id: string | null; + /** + * + * @type {string} + * @memberof CustomerBenefitGrantAds + */ + order_id: string | null; + /** + * + * @type {boolean} + * @memberof CustomerBenefitGrantAds + */ + is_granted: boolean; + /** + * + * @type {boolean} + * @memberof CustomerBenefitGrantAds + */ + is_revoked: boolean; + /** + * + * @type {BenefitAdsSubscriber} + * @memberof CustomerBenefitGrantAds + */ + benefit: BenefitAdsSubscriber; + /** + * + * @type {BenefitGrantAdsProperties} + * @memberof CustomerBenefitGrantAds + */ + properties: BenefitGrantAdsProperties; } /** - * @type Discount - * - * @export - */ -export type Discount = DiscountFixedOnceForeverDuration | DiscountFixedRepeatDuration | DiscountPercentageOnceForeverDuration | DiscountPercentageRepeatDuration; -/** - * @type DiscountCreate * * @export + * @interface CustomerBenefitGrantAdsUpdate */ -export type DiscountCreate = DiscountFixedOnceForeverDurationCreate | DiscountFixedRepeatDurationCreate | DiscountPercentageOnceForeverDurationCreate | DiscountPercentageRepeatDurationCreate; +export interface CustomerBenefitGrantAdsUpdate { + /** + * + * @type {string} + * @memberof CustomerBenefitGrantAdsUpdate + */ + benefit_type: CustomerBenefitGrantAdsUpdateBenefitTypeEnum; +} + /** - * * @export */ -export const DiscountDuration = { - ONCE: 'once', - FOREVER: 'forever', - REPEATING: 'repeating' +export const CustomerBenefitGrantAdsUpdateBenefitTypeEnum = { + ADS: 'ads' } as const; -export type DiscountDuration = typeof DiscountDuration[keyof typeof DiscountDuration]; +export type CustomerBenefitGrantAdsUpdateBenefitTypeEnum = typeof CustomerBenefitGrantAdsUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantAdsUpdateBenefitTypeEnum]; /** - * Schema for a fixed amount discount that is applied once or forever. + * * @export - * @interface DiscountFixedOnceForeverDuration + * @interface CustomerBenefitGrantCustom */ -export interface DiscountFixedOnceForeverDuration { - /** - * - * @type {DiscountDuration} - * @memberof DiscountFixedOnceForeverDuration - */ - duration: DiscountDuration; - /** - * - * @type {DiscountType} - * @memberof DiscountFixedOnceForeverDuration - */ - type: DiscountType; - /** - * - * @type {number} - * @memberof DiscountFixedOnceForeverDuration - */ - amount: number; - /** - * - * @type {string} - * @memberof DiscountFixedOnceForeverDuration - */ - currency: string; +export interface CustomerBenefitGrantCustom { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ created_at: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountFixedOnceForeverDuration - */ - metadata: { [key: string]: MetadataValue; }; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ - name: string; + granted_at: string | null; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ - code: string | null; + revoked_at: string | null; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ - starts_at: string | null; + customer_id: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ - ends_at: string | null; + benefit_id: string; /** * - * @type {number} - * @memberof DiscountFixedOnceForeverDuration - */ - max_redemptions: number | null; - /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountFixedOnceForeverDuration + * @type {string} + * @memberof CustomerBenefitGrantCustom */ - redemptions_count: number; + subscription_id: string | null; /** - * The organization ID. + * * @type {string} - * @memberof DiscountFixedOnceForeverDuration + * @memberof CustomerBenefitGrantCustom */ - organization_id: string; + order_id: string | null; /** * - * @type {Array} - * @memberof DiscountFixedOnceForeverDuration + * @type {boolean} + * @memberof CustomerBenefitGrantCustom */ - products: Array; -} - - -/** - * - * @export - * @interface DiscountFixedOnceForeverDurationBase - */ -export interface DiscountFixedOnceForeverDurationBase { + is_granted: boolean; /** * - * @type {DiscountDuration} - * @memberof DiscountFixedOnceForeverDurationBase + * @type {boolean} + * @memberof CustomerBenefitGrantCustom */ - duration: DiscountDuration; + is_revoked: boolean; /** * - * @type {DiscountType} - * @memberof DiscountFixedOnceForeverDurationBase + * @type {BenefitCustomSubscriber} + * @memberof CustomerBenefitGrantCustom */ - type: DiscountType; + benefit: BenefitCustomSubscriber; /** * - * @type {number} - * @memberof DiscountFixedOnceForeverDurationBase + * @type {object} + * @memberof CustomerBenefitGrantCustom */ - amount: number; + properties: object; +} +/** + * + * @export + * @interface CustomerBenefitGrantCustomUpdate + */ +export interface CustomerBenefitGrantCustomUpdate { /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantCustomUpdate */ - currency: string; + benefit_type: CustomerBenefitGrantCustomUpdateBenefitTypeEnum; +} + + +/** + * @export + */ +export const CustomerBenefitGrantCustomUpdateBenefitTypeEnum = { + CUSTOM: 'custom' +} as const; +export type CustomerBenefitGrantCustomUpdateBenefitTypeEnum = typeof CustomerBenefitGrantCustomUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantCustomUpdateBenefitTypeEnum]; + +/** + * + * @export + * @interface CustomerBenefitGrantDiscord + */ +export interface CustomerBenefitGrantDiscord { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ created_at: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountFixedOnceForeverDurationBase + * @type {string} + * @memberof CustomerBenefitGrantDiscord */ - metadata: { [key: string]: MetadataValue; }; + granted_at: string | null; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ - name: string; + revoked_at: string | null; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ - code: string | null; + customer_id: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ - starts_at: string | null; + benefit_id: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * @memberof CustomerBenefitGrantDiscord */ - ends_at: string | null; + subscription_id: string | null; /** * - * @type {number} - * @memberof DiscountFixedOnceForeverDurationBase + * @type {string} + * @memberof CustomerBenefitGrantDiscord */ - max_redemptions: number | null; + order_id: string | null; /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountFixedOnceForeverDurationBase + * + * @type {boolean} + * @memberof CustomerBenefitGrantDiscord */ - redemptions_count: number; + is_granted: boolean; /** - * The organization ID. - * @type {string} - * @memberof DiscountFixedOnceForeverDurationBase + * + * @type {boolean} + * @memberof CustomerBenefitGrantDiscord */ - organization_id: string; + is_revoked: boolean; + /** + * + * @type {BenefitDiscordSubscriber} + * @memberof CustomerBenefitGrantDiscord + */ + benefit: BenefitDiscordSubscriber; + /** + * + * @type {BenefitGrantDiscordProperties} + * @memberof CustomerBenefitGrantDiscord + */ + properties: BenefitGrantDiscordProperties; } - - /** - * Schema to create a fixed amount discount that is applied once or forever. + * * @export - * @interface DiscountFixedOnceForeverDurationCreate + * @interface CustomerBenefitGrantDiscordPropertiesUpdate */ -export interface DiscountFixedOnceForeverDurationCreate { +export interface CustomerBenefitGrantDiscordPropertiesUpdate { /** * - * @type {DiscountDuration} - * @memberof DiscountFixedOnceForeverDurationCreate + * @type {string} + * @memberof CustomerBenefitGrantDiscordPropertiesUpdate */ - duration: DiscountDuration; + account_id: string; +} +/** + * + * @export + * @interface CustomerBenefitGrantDiscordUpdate + */ +export interface CustomerBenefitGrantDiscordUpdate { /** * - * @type {DiscountType} - * @memberof DiscountFixedOnceForeverDurationCreate + * @type {string} + * @memberof CustomerBenefitGrantDiscordUpdate */ - type: DiscountType; + benefit_type: CustomerBenefitGrantDiscordUpdateBenefitTypeEnum; /** - * Fixed amount to discount from the invoice total. - * @type {number} - * @memberof DiscountFixedOnceForeverDurationCreate + * + * @type {CustomerBenefitGrantDiscordPropertiesUpdate} + * @memberof CustomerBenefitGrantDiscordUpdate */ - amount: number; + properties: CustomerBenefitGrantDiscordPropertiesUpdate; +} + + +/** + * @export + */ +export const CustomerBenefitGrantDiscordUpdateBenefitTypeEnum = { + DISCORD: 'discord' +} as const; +export type CustomerBenefitGrantDiscordUpdateBenefitTypeEnum = typeof CustomerBenefitGrantDiscordUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantDiscordUpdateBenefitTypeEnum]; + +/** + * + * @export + * @interface CustomerBenefitGrantDownloadables + */ +export interface CustomerBenefitGrantDownloadables { /** - * The currency. Currently, only `usd` is supported. + * Creation timestamp of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - currency?: string; + created_at: string; /** - * Key-value object allowing you to store additional information. - * - * The key must be a string with a maximum length of **40 characters**. - * The value must be either: - * - * * A string with a maximum length of **500 characters** - * * An integer - * * A boolean * - * You can store up to **50 key-value pairs**. - * @type {{ [key: string]: MetadataValue1; }} - * @memberof DiscountFixedOnceForeverDurationCreate + * @type {string} + * @memberof CustomerBenefitGrantDownloadables */ - metadata?: { [key: string]: MetadataValue1; }; + modified_at: string | null; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * The ID of the object. * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - name: string; + id: string; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - code?: string | null; + granted_at: string | null; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - starts_at?: string | null; + revoked_at: string | null; /** * * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - ends_at?: string | null; + customer_id: string; /** * - * @type {number} - * @memberof DiscountFixedOnceForeverDurationCreate + * @type {string} + * @memberof CustomerBenefitGrantDownloadables */ - max_redemptions?: number | null; + benefit_id: string; /** - * List of product IDs the discount can be applied to. - * @type {Array} - * @memberof DiscountFixedOnceForeverDurationCreate + * + * @type {string} + * @memberof CustomerBenefitGrantDownloadables */ - products?: Array | null; + subscription_id: string | null; /** - * The organization ID. + * * @type {string} - * @memberof DiscountFixedOnceForeverDurationCreate + * @memberof CustomerBenefitGrantDownloadables */ - organization_id?: string | null; -} - - -/** - * Schema for a fixed amount discount that is applied on every invoice - * for a certain number of months. - * @export - * @interface DiscountFixedRepeatDuration - */ -export interface DiscountFixedRepeatDuration { + order_id: string | null; /** * - * @type {DiscountDuration} - * @memberof DiscountFixedRepeatDuration + * @type {boolean} + * @memberof CustomerBenefitGrantDownloadables */ - duration: DiscountDuration; + is_granted: boolean; /** * - * @type {number} - * @memberof DiscountFixedRepeatDuration + * @type {boolean} + * @memberof CustomerBenefitGrantDownloadables */ - duration_in_months: number; + is_revoked: boolean; /** * - * @type {DiscountType} - * @memberof DiscountFixedRepeatDuration + * @type {BenefitDownloadablesSubscriber} + * @memberof CustomerBenefitGrantDownloadables */ - type: DiscountType; + benefit: BenefitDownloadablesSubscriber; /** * - * @type {number} - * @memberof DiscountFixedRepeatDuration + * @type {BenefitGrantDownloadablesProperties} + * @memberof CustomerBenefitGrantDownloadables */ - amount: number; + properties: BenefitGrantDownloadablesProperties; +} +/** + * + * @export + * @interface CustomerBenefitGrantDownloadablesUpdate + */ +export interface CustomerBenefitGrantDownloadablesUpdate { /** * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantDownloadablesUpdate */ - currency: string; + benefit_type: CustomerBenefitGrantDownloadablesUpdateBenefitTypeEnum; +} + + +/** + * @export + */ +export const CustomerBenefitGrantDownloadablesUpdateBenefitTypeEnum = { + DOWNLOADABLES: 'downloadables' +} as const; +export type CustomerBenefitGrantDownloadablesUpdateBenefitTypeEnum = typeof CustomerBenefitGrantDownloadablesUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantDownloadablesUpdateBenefitTypeEnum]; + +/** + * + * @export + * @interface CustomerBenefitGrantGitHubRepository + */ +export interface CustomerBenefitGrantGitHubRepository { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ created_at: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountFixedRepeatDuration + * @type {string} + * @memberof CustomerBenefitGrantGitHubRepository */ - metadata: { [key: string]: MetadataValue; }; + granted_at: string | null; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ - name: string; + revoked_at: string | null; /** * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ - code: string | null; + customer_id: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ - starts_at: string | null; + benefit_id: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDuration + * @memberof CustomerBenefitGrantGitHubRepository */ - ends_at: string | null; + subscription_id: string | null; /** * - * @type {number} - * @memberof DiscountFixedRepeatDuration + * @type {string} + * @memberof CustomerBenefitGrantGitHubRepository */ - max_redemptions: number | null; + order_id: string | null; /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountFixedRepeatDuration + * + * @type {boolean} + * @memberof CustomerBenefitGrantGitHubRepository */ - redemptions_count: number; + is_granted: boolean; /** - * The organization ID. - * @type {string} - * @memberof DiscountFixedRepeatDuration + * + * @type {boolean} + * @memberof CustomerBenefitGrantGitHubRepository */ - organization_id: string; + is_revoked: boolean; /** * - * @type {Array} - * @memberof DiscountFixedRepeatDuration + * @type {BenefitGitHubRepositorySubscriber} + * @memberof CustomerBenefitGrantGitHubRepository */ - products: Array; + benefit: BenefitGitHubRepositorySubscriber; + /** + * + * @type {BenefitGrantGitHubRepositoryProperties} + * @memberof CustomerBenefitGrantGitHubRepository + */ + properties: BenefitGrantGitHubRepositoryProperties; } - - /** * * @export - * @interface DiscountFixedRepeatDurationBase + * @interface CustomerBenefitGrantGitHubRepositoryPropertiesUpdate */ -export interface DiscountFixedRepeatDurationBase { - /** - * - * @type {DiscountDuration} - * @memberof DiscountFixedRepeatDurationBase - */ - duration: DiscountDuration; - /** - * - * @type {number} - * @memberof DiscountFixedRepeatDurationBase - */ - duration_in_months: number; +export interface CustomerBenefitGrantGitHubRepositoryPropertiesUpdate { /** * - * @type {DiscountType} - * @memberof DiscountFixedRepeatDurationBase + * @type {string} + * @memberof CustomerBenefitGrantGitHubRepositoryPropertiesUpdate */ - type: DiscountType; + account_id: string; +} +/** + * + * @export + * @interface CustomerBenefitGrantGitHubRepositoryUpdate + */ +export interface CustomerBenefitGrantGitHubRepositoryUpdate { /** * - * @type {number} - * @memberof DiscountFixedRepeatDurationBase + * @type {string} + * @memberof CustomerBenefitGrantGitHubRepositoryUpdate */ - amount: number; + benefit_type: CustomerBenefitGrantGitHubRepositoryUpdateBenefitTypeEnum; /** * - * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @type {CustomerBenefitGrantGitHubRepositoryPropertiesUpdate} + * @memberof CustomerBenefitGrantGitHubRepositoryUpdate */ - currency: string; + properties: CustomerBenefitGrantGitHubRepositoryPropertiesUpdate; +} + + +/** + * @export + */ +export const CustomerBenefitGrantGitHubRepositoryUpdateBenefitTypeEnum = { + GITHUB_REPOSITORY: 'github_repository' +} as const; +export type CustomerBenefitGrantGitHubRepositoryUpdateBenefitTypeEnum = typeof CustomerBenefitGrantGitHubRepositoryUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantGitHubRepositoryUpdateBenefitTypeEnum]; + +/** + * + * @export + * @interface CustomerBenefitGrantLicenseKeys + */ +export interface CustomerBenefitGrantLicenseKeys { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ created_at: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountFixedRepeatDurationBase + * @type {string} + * @memberof CustomerBenefitGrantLicenseKeys */ - metadata: { [key: string]: MetadataValue; }; + granted_at: string | null; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ - name: string; + revoked_at: string | null; /** * * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ - code: string | null; + customer_id: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ - starts_at: string | null; + benefit_id: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ - ends_at: string | null; + subscription_id: string | null; /** * - * @type {number} - * @memberof DiscountFixedRepeatDurationBase - */ - max_redemptions: number | null; - /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountFixedRepeatDurationBase - */ - redemptions_count: number; - /** - * The organization ID. * @type {string} - * @memberof DiscountFixedRepeatDurationBase + * @memberof CustomerBenefitGrantLicenseKeys */ - organization_id: string; -} - - -/** - * Schema to create a fixed amount discount that is applied on every invoice - * for a certain number of months. - * @export - * @interface DiscountFixedRepeatDurationCreate - */ -export interface DiscountFixedRepeatDurationCreate { + order_id: string | null; /** * - * @type {DiscountDuration} - * @memberof DiscountFixedRepeatDurationCreate + * @type {boolean} + * @memberof CustomerBenefitGrantLicenseKeys */ - duration: DiscountDuration; + is_granted: boolean; /** - * Number of months the discount should be applied. * - * For this to work on yearly pricing, you should multiply this by 12. - * For example, to apply the discount for 2 years, set this to 24. - * @type {number} - * @memberof DiscountFixedRepeatDurationCreate + * @type {boolean} + * @memberof CustomerBenefitGrantLicenseKeys */ - duration_in_months: number; + is_revoked: boolean; /** * - * @type {DiscountType} - * @memberof DiscountFixedRepeatDurationCreate - */ - type: DiscountType; - /** - * Fixed amount to discount from the invoice total. - * @type {number} - * @memberof DiscountFixedRepeatDurationCreate - */ - amount: number; - /** - * The currency. Currently, only `usd` is supported. - * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @type {BenefitLicenseKeysSubscriber} + * @memberof CustomerBenefitGrantLicenseKeys */ - currency?: string; + benefit: BenefitLicenseKeysSubscriber; /** - * Key-value object allowing you to store additional information. - * - * The key must be a string with a maximum length of **40 characters**. - * The value must be either: - * - * * A string with a maximum length of **500 characters** - * * An integer - * * A boolean * - * You can store up to **50 key-value pairs**. - * @type {{ [key: string]: MetadataValue1; }} - * @memberof DiscountFixedRepeatDurationCreate - */ - metadata?: { [key: string]: MetadataValue1; }; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. - * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @type {BenefitGrantLicenseKeysProperties} + * @memberof CustomerBenefitGrantLicenseKeys */ - name: string; + properties: BenefitGrantLicenseKeysProperties; +} +/** + * + * @export + * @interface CustomerBenefitGrantLicenseKeysUpdate + */ +export interface CustomerBenefitGrantLicenseKeysUpdate { /** * * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @memberof CustomerBenefitGrantLicenseKeysUpdate */ - code?: string | null; + benefit_type: CustomerBenefitGrantLicenseKeysUpdateBenefitTypeEnum; +} + + +/** + * @export + */ +export const CustomerBenefitGrantLicenseKeysUpdateBenefitTypeEnum = { + LICENSE_KEYS: 'license_keys' +} as const; +export type CustomerBenefitGrantLicenseKeysUpdateBenefitTypeEnum = typeof CustomerBenefitGrantLicenseKeysUpdateBenefitTypeEnum[keyof typeof CustomerBenefitGrantLicenseKeysUpdateBenefitTypeEnum]; + + +/** + * + * @export + */ +export const CustomerBenefitGrantSortProperty = { + GRANTED_AT: 'granted_at', + GRANTED_AT2: '-granted_at', + TYPE: 'type', + TYPE2: '-type', + ORGANIZATION: 'organization', + ORGANIZATION2: '-organization' +} as const; +export type CustomerBenefitGrantSortProperty = typeof CustomerBenefitGrantSortProperty[keyof typeof CustomerBenefitGrantSortProperty]; + +/** + * @type CustomerBenefitGrantUpdate + * + * @export + */ +export type CustomerBenefitGrantUpdate = { benefit_type: 'ads' } & CustomerBenefitGrantAdsUpdate | { benefit_type: 'custom' } & CustomerBenefitGrantCustomUpdate | { benefit_type: 'discord' } & CustomerBenefitGrantDiscordUpdate | { benefit_type: 'downloadables' } & CustomerBenefitGrantDownloadablesUpdate | { benefit_type: 'github_repository' } & CustomerBenefitGrantGitHubRepositoryUpdate | { benefit_type: 'license_keys' } & CustomerBenefitGrantLicenseKeysUpdate; +/** + * + * @export + * @interface CustomerCreate + */ +export interface CustomerCreate { /** * * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @memberof CustomerCreate */ - starts_at?: string | null; + email: string; /** * * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @memberof CustomerCreate */ - ends_at?: string | null; + name?: string | null; /** * - * @type {number} - * @memberof DiscountFixedRepeatDurationCreate + * @type {Address} + * @memberof CustomerCreate */ - max_redemptions?: number | null; + billing_address?: Address | null; /** - * List of product IDs the discount can be applied to. + * * @type {Array} - * @memberof DiscountFixedRepeatDurationCreate + * @memberof CustomerCreate */ - products?: Array | null; + tax_id?: Array | null; /** * The organization ID. * @type {string} - * @memberof DiscountFixedRepeatDurationCreate + * @memberof CustomerCreate */ organization_id?: string | null; } - +/** + * @type CustomerIDFilter + * Filter by customer ID. + * @export + */ +export type CustomerIDFilter = Array | string; /** - * @type DiscountIDFilter - * Filter by discount ID. + * @type CustomerIDFilter1 + * Filter by customer. * @export */ -export type DiscountIDFilter = Array | string; +export type CustomerIDFilter1 = Array | string; + /** - * @type DiscountIDFilter1 - * Filter by discount ID. + * * @export */ -export type DiscountIDFilter1 = Array | string; +export const CustomerOAuthPlatform = { + GITHUB: 'github', + DISCORD: 'discord' +} as const; +export type CustomerOAuthPlatform = typeof CustomerOAuthPlatform[keyof typeof CustomerOAuthPlatform]; /** - * Schema for a percentage discount that is applied once or forever. + * * @export - * @interface DiscountPercentageOnceForeverDuration + * @interface CustomerOrder */ -export interface DiscountPercentageOnceForeverDuration { - /** - * - * @type {DiscountDuration} - * @memberof DiscountPercentageOnceForeverDuration - */ - duration: DiscountDuration; - /** - * - * @type {DiscountType} - * @memberof DiscountPercentageOnceForeverDuration - */ - type: DiscountType; - /** - * - * @type {number} - * @memberof DiscountPercentageOnceForeverDuration - */ - basis_points: number; +export interface CustomerOrder { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ created_at: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ modified_at: string | null; /** - * The ID of the object. + * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountPercentageOnceForeverDuration + * @type {number} + * @memberof CustomerOrder */ - metadata: { [key: string]: MetadataValue; }; + amount: number; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. - * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * + * @type {number} + * @memberof CustomerOrder */ - name: string; + tax_amount: number; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ - code: string | null; + currency: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ - starts_at: string | null; + customer_id: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder */ - ends_at: string | null; + product_id: string; /** * - * @type {number} - * @memberof DiscountPercentageOnceForeverDuration + * @type {string} + * @memberof CustomerOrder */ - max_redemptions: number | null; + product_price_id: string; /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountPercentageOnceForeverDuration + * + * @type {string} + * @memberof CustomerOrder */ - redemptions_count: number; + subscription_id: string | null; /** - * The organization ID. + * * @type {string} - * @memberof DiscountPercentageOnceForeverDuration + * @memberof CustomerOrder + * @deprecated */ - organization_id: string; + user_id: string; /** * - * @type {Array} - * @memberof DiscountPercentageOnceForeverDuration + * @type {CustomerOrderProduct} + * @memberof CustomerOrder */ - products: Array; -} - - -/** - * - * @export - * @interface DiscountPercentageOnceForeverDurationBase - */ -export interface DiscountPercentageOnceForeverDurationBase { + product: CustomerOrderProduct; /** * - * @type {DiscountDuration} - * @memberof DiscountPercentageOnceForeverDurationBase + * @type {ProductPrice} + * @memberof CustomerOrder */ - duration: DiscountDuration; + product_price: ProductPrice; /** * - * @type {DiscountType} - * @memberof DiscountPercentageOnceForeverDurationBase + * @type {CustomerOrderSubscription} + * @memberof CustomerOrder */ - type: DiscountType; + subscription: CustomerOrderSubscription | null; +} +/** + * Order's invoice data. + * @export + * @interface CustomerOrderInvoice + */ +export interface CustomerOrderInvoice { /** - * - * @type {number} - * @memberof DiscountPercentageOnceForeverDurationBase + * The URL to the invoice. + * @type {string} + * @memberof CustomerOrderInvoice */ - basis_points: number; + url: string; +} +/** + * + * @export + * @interface CustomerOrderProduct + */ +export interface CustomerOrderProduct { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ created_at: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ modified_at: string | null; /** - * The ID of the object. + * The ID of the product. * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ id: string; /** - * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountPercentageOnceForeverDurationBase - */ - metadata: { [key: string]: MetadataValue; }; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * The name of the product. * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ name: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ - code: string | null; + description: string | null; /** - * - * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * Whether the product is a subscription tier. + * @type {boolean} + * @memberof CustomerOrderProduct */ - starts_at: string | null; + is_recurring: boolean; /** - * + * Whether the product is archived and no longer available. + * @type {boolean} + * @memberof CustomerOrderProduct + */ + is_archived: boolean; + /** + * The ID of the organization owning the product. * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * @memberof CustomerOrderProduct */ - ends_at: string | null; + organization_id: string; /** - * - * @type {number} - * @memberof DiscountPercentageOnceForeverDurationBase + * List of prices for this product. + * @type {Array} + * @memberof CustomerOrderProduct */ - max_redemptions: number | null; + prices: Array; /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountPercentageOnceForeverDurationBase + * List of benefits granted by the product. + * @type {Array} + * @memberof CustomerOrderProduct */ - redemptions_count: number; + benefits: Array; /** - * The organization ID. - * @type {string} - * @memberof DiscountPercentageOnceForeverDurationBase + * List of medias associated to the product. + * @type {Array} + * @memberof CustomerOrderProduct */ - organization_id: string; + medias: Array; + /** + * + * @type {Organization} + * @memberof CustomerOrderProduct + */ + organization: Organization; } +/** + * + * @export + */ +export const CustomerOrderSortProperty = { + CREATED_AT: 'created_at', + CREATED_AT2: '-created_at', + AMOUNT: 'amount', + AMOUNT2: '-amount', + ORGANIZATION: 'organization', + ORGANIZATION2: '-organization', + PRODUCT: 'product', + PRODUCT2: '-product', + SUBSCRIPTION: 'subscription', + SUBSCRIPTION2: '-subscription' +} as const; +export type CustomerOrderSortProperty = typeof CustomerOrderSortProperty[keyof typeof CustomerOrderSortProperty]; /** - * Schema to create a percentage discount that is applied once or forever. + * * @export - * @interface DiscountPercentageOnceForeverDurationCreate + * @interface CustomerOrderSubscription */ -export interface DiscountPercentageOnceForeverDurationCreate { +export interface CustomerOrderSubscription { /** - * - * @type {DiscountDuration} - * @memberof DiscountPercentageOnceForeverDurationCreate + * Creation timestamp of the object. + * @type {string} + * @memberof CustomerOrderSubscription */ - duration: DiscountDuration; + created_at: string; /** * - * @type {DiscountType} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @type {string} + * @memberof CustomerOrderSubscription */ - type: DiscountType; + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof CustomerOrderSubscription + */ + id: string; /** - * Discount percentage in basis points. * - * A basis point is 1/100th of a percent. - * For example, to create a 25.5% discount, set this to 2550. * @type {number} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @memberof CustomerOrderSubscription */ - basis_points: number; + amount: number | null; /** - * Key-value object allowing you to store additional information. - * - * The key must be a string with a maximum length of **40 characters**. - * The value must be either: * - * * A string with a maximum length of **500 characters** - * * An integer - * * A boolean + * @type {string} + * @memberof CustomerOrderSubscription + */ + currency: string | null; + /** * - * You can store up to **50 key-value pairs**. - * @type {{ [key: string]: MetadataValue1; }} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @type {SubscriptionRecurringInterval} + * @memberof CustomerOrderSubscription */ - metadata?: { [key: string]: MetadataValue1; }; + recurring_interval: SubscriptionRecurringInterval; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. - * @type {string} - * @memberof DiscountPercentageOnceForeverDurationCreate + * + * @type {SubscriptionStatus} + * @memberof CustomerOrderSubscription */ - name: string; + status: SubscriptionStatus; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @memberof CustomerOrderSubscription */ - code?: string | null; + current_period_start: string; /** * * @type {string} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @memberof CustomerOrderSubscription */ - starts_at?: string | null; + current_period_end: string | null; /** * - * @type {string} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @type {boolean} + * @memberof CustomerOrderSubscription */ - ends_at?: string | null; + cancel_at_period_end: boolean; /** * - * @type {number} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @type {string} + * @memberof CustomerOrderSubscription */ - max_redemptions?: number | null; + started_at: string | null; /** - * List of product IDs the discount can be applied to. - * @type {Array} - * @memberof DiscountPercentageOnceForeverDurationCreate + * + * @type {string} + * @memberof CustomerOrderSubscription */ - products?: Array | null; + ended_at: string | null; /** - * The organization ID. + * * @type {string} - * @memberof DiscountPercentageOnceForeverDurationCreate + * @memberof CustomerOrderSubscription */ - organization_id?: string | null; -} - - -/** - * Schema for a percentage discount that is applied on every invoice - * for a certain number of months. - * @export - * @interface DiscountPercentageRepeatDuration - */ -export interface DiscountPercentageRepeatDuration { + customer_id: string; /** * - * @type {DiscountDuration} - * @memberof DiscountPercentageRepeatDuration + * @type {string} + * @memberof CustomerOrderSubscription */ - duration: DiscountDuration; + product_id: string; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDuration + * @type {string} + * @memberof CustomerOrderSubscription */ - duration_in_months: number; + price_id: string; /** * - * @type {DiscountType} - * @memberof DiscountPercentageRepeatDuration + * @type {string} + * @memberof CustomerOrderSubscription */ - type: DiscountType; + discount_id: string | null; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDuration + * @type {string} + * @memberof CustomerOrderSubscription */ - basis_points: number; + checkout_id: string | null; +} + + +/** + * + * @export + * @interface CustomerPortalCustomer + */ +export interface CustomerPortalCustomer { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalCustomer */ created_at: string; /** * * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalCustomer */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalCustomer */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountPercentageRepeatDuration - */ - metadata: { [key: string]: MetadataValue; }; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalCustomer */ - name: string; + email: string; /** * - * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @type {boolean} + * @memberof CustomerPortalCustomer */ - code: string | null; + email_verified: boolean; /** * * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalCustomer */ - starts_at: string | null; + name: string | null; /** * - * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @type {Address} + * @memberof CustomerPortalCustomer */ - ends_at: string | null; + billing_address: Address | null; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDuration + * @type {Array} + * @memberof CustomerPortalCustomer */ - max_redemptions: number | null; + tax_id: Array | null; /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountPercentageRepeatDuration + * + * @type {{ [key: string]: CustomerPortalOAuthAccount; }} + * @memberof CustomerPortalCustomer */ - redemptions_count: number; + oauth_accounts: { [key: string]: CustomerPortalOAuthAccount; }; +} +/** + * + * @export + * @interface CustomerPortalOAuthAccount + */ +export interface CustomerPortalOAuthAccount { /** - * The organization ID. + * * @type {string} - * @memberof DiscountPercentageRepeatDuration + * @memberof CustomerPortalOAuthAccount */ - organization_id: string; + account_id: string; /** * - * @type {Array} - * @memberof DiscountPercentageRepeatDuration + * @type {string} + * @memberof CustomerPortalOAuthAccount */ - products: Array; + account_username: string | null; } - - /** * * @export - * @interface DiscountPercentageRepeatDurationBase + * @interface CustomerSessionCodeAuthenticateRequest */ -export interface DiscountPercentageRepeatDurationBase { +export interface CustomerSessionCodeAuthenticateRequest { /** * - * @type {DiscountDuration} - * @memberof DiscountPercentageRepeatDurationBase + * @type {string} + * @memberof CustomerSessionCodeAuthenticateRequest */ - duration: DiscountDuration; + code: string; +} +/** + * + * @export + * @interface CustomerSessionCodeAuthenticateResponse + */ +export interface CustomerSessionCodeAuthenticateResponse { /** * - * @type {number} - * @memberof DiscountPercentageRepeatDurationBase + * @type {string} + * @memberof CustomerSessionCodeAuthenticateResponse */ - duration_in_months: number; + token: string; +} +/** + * + * @export + * @interface CustomerSessionCodeRequest + */ +export interface CustomerSessionCodeRequest { /** * - * @type {DiscountType} - * @memberof DiscountPercentageRepeatDurationBase + * @type {string} + * @memberof CustomerSessionCodeRequest */ - type: DiscountType; + email: string; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDurationBase + * @type {string} + * @memberof CustomerSessionCodeRequest */ - basis_points: number; + organization_id: string; +} + +/** + * + * @export + */ +export const CustomerSortProperty = { + CREATED_AT: 'created_at', + CREATED_AT2: '-created_at', + EMAIL: 'email', + EMAIL2: '-email', + NAME: 'name', + NAME2: '-name' +} as const; +export type CustomerSortProperty = typeof CustomerSortProperty[keyof typeof CustomerSortProperty]; + +/** + * + * @export + * @interface CustomerSubscription + */ +export interface CustomerSubscription { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ created_at: string; /** * * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ id: string; /** * - * @type {{ [key: string]: MetadataValue; }} - * @memberof DiscountPercentageRepeatDurationBase + * @type {number} + * @memberof CustomerSubscription */ - metadata: { [key: string]: MetadataValue; }; + amount: number | null; /** - * Name of the discount. Will be displayed to the customer when the discount is applied. + * * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ - name: string; + currency: string | null; /** * - * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @type {SubscriptionRecurringInterval} + * @memberof CustomerSubscription */ - code: string | null; + recurring_interval: SubscriptionRecurringInterval; /** * - * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @type {SubscriptionStatus} + * @memberof CustomerSubscription */ - starts_at: string | null; + status: SubscriptionStatus; /** * * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ - ends_at: string | null; + current_period_start: string; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDurationBase - */ - max_redemptions: number | null; - /** - * Number of times the discount has been redeemed. - * @type {number} - * @memberof DiscountPercentageRepeatDurationBase - */ - redemptions_count: number; - /** - * The organization ID. * @type {string} - * @memberof DiscountPercentageRepeatDurationBase + * @memberof CustomerSubscription */ - organization_id: string; -} - - -/** - * Schema to create a percentage discount that is applied on every invoice - * for a certain number of months. - * @export - * @interface DiscountPercentageRepeatDurationCreate - */ -export interface DiscountPercentageRepeatDurationCreate { + current_period_end: string | null; /** * - * @type {DiscountDuration} - * @memberof DiscountPercentageRepeatDurationCreate + * @type {boolean} + * @memberof CustomerSubscription */ - duration: DiscountDuration; + cancel_at_period_end: boolean; /** - * Number of months the discount should be applied. * - * For this to work on yearly pricing, you should multiply this by 12. - * For example, to apply the discount for 2 years, set this to 24. - * @type {number} - * @memberof DiscountPercentageRepeatDurationCreate + * @type {string} + * @memberof CustomerSubscription */ - duration_in_months: number; + started_at: string | null; /** * - * @type {DiscountType} - * @memberof DiscountPercentageRepeatDurationCreate + * @type {string} + * @memberof CustomerSubscription */ - type: DiscountType; + ended_at: string | null; /** - * Discount percentage in basis points. * - * A basis point is 1/100th of a percent. - * For example, to create a 25.5% discount, set this to 2550. - * @type {number} - * @memberof DiscountPercentageRepeatDurationCreate + * @type {string} + * @memberof CustomerSubscription */ - basis_points: number; + customer_id: string; /** - * Key-value object allowing you to store additional information. * - * The key must be a string with a maximum length of **40 characters**. - * The value must be either: - * - * * A string with a maximum length of **500 characters** - * * An integer - * * A boolean - * - * You can store up to **50 key-value pairs**. - * @type {{ [key: string]: MetadataValue1; }} - * @memberof DiscountPercentageRepeatDurationCreate - */ - metadata?: { [key: string]: MetadataValue1; }; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DiscountPercentageRepeatDurationCreate + * @memberof CustomerSubscription */ - name: string; + product_id: string; /** * * @type {string} - * @memberof DiscountPercentageRepeatDurationCreate + * @memberof CustomerSubscription */ - code?: string | null; + price_id: string; /** * * @type {string} - * @memberof DiscountPercentageRepeatDurationCreate + * @memberof CustomerSubscription */ - starts_at?: string | null; + discount_id: string | null; /** * * @type {string} - * @memberof DiscountPercentageRepeatDurationCreate + * @memberof CustomerSubscription */ - ends_at?: string | null; + checkout_id: string | null; /** * - * @type {number} - * @memberof DiscountPercentageRepeatDurationCreate + * @type {string} + * @memberof CustomerSubscription + * @deprecated */ - max_redemptions?: number | null; + user_id: string; /** - * List of product IDs the discount can be applied to. - * @type {Array} - * @memberof DiscountPercentageRepeatDurationCreate + * + * @type {CustomerSubscriptionProduct} + * @memberof CustomerSubscription */ - products?: Array | null; + product: CustomerSubscriptionProduct; /** - * The organization ID. - * @type {string} - * @memberof DiscountPercentageRepeatDurationCreate + * + * @type {ProductPrice} + * @memberof CustomerSubscription */ - organization_id?: string | null; + price: ProductPrice; } /** - * A product that a discount can be applied to. + * * @export - * @interface DiscountProduct + * @interface CustomerSubscriptionProduct */ -export interface DiscountProduct { +export interface CustomerSubscriptionProduct { /** * Creation timestamp of the object. * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ created_at: string; /** * * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ modified_at: string | null; /** * The ID of the product. * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ id: string; /** * The name of the product. * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ name: string; /** * * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ description: string | null; /** * Whether the product is a subscription tier. * @type {boolean} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ is_recurring: boolean; /** * Whether the product is archived and no longer available. * @type {boolean} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ is_archived: boolean; /** * The ID of the organization owning the product. * @type {string} - * @memberof DiscountProduct + * @memberof CustomerSubscriptionProduct */ organization_id: string; + /** + * List of prices for this product. + * @type {Array} + * @memberof CustomerSubscriptionProduct + */ + prices: Array; + /** + * List of benefits granted by the product. + * @type {Array} + * @memberof CustomerSubscriptionProduct + */ + benefits: Array; + /** + * List of medias associated to the product. + * @type {Array} + * @memberof CustomerSubscriptionProduct + */ + medias: Array; + /** + * + * @type {Organization} + * @memberof CustomerSubscriptionProduct + */ + organization: Organization; } /** * * @export */ -export const DiscountSortProperty = { - CREATED_AT: 'created_at', - CREATED_AT2: '-created_at', - NAME: 'name', - NAME2: '-name', - CODE: 'code', - CODE2: '-code', - REDEMPTIONS_COUNT: 'redemptions_count', - REDEMPTIONS_COUNT2: '-redemptions_count' +export const CustomerSubscriptionSortProperty = { + STARTED_AT: 'started_at', + STARTED_AT2: '-started_at', + AMOUNT: 'amount', + AMOUNT2: '-amount', + STATUS: 'status', + STATUS2: '-status', + ORGANIZATION: 'organization', + ORGANIZATION2: '-organization', + PRODUCT: 'product', + PRODUCT2: '-product' } as const; -export type DiscountSortProperty = typeof DiscountSortProperty[keyof typeof DiscountSortProperty]; - +export type CustomerSubscriptionSortProperty = typeof CustomerSubscriptionSortProperty[keyof typeof CustomerSubscriptionSortProperty]; /** * * @export + * @interface CustomerSubscriptionUpdate */ -export const DiscountType = { - FIXED: 'fixed', - PERCENTAGE: 'percentage' -} as const; -export type DiscountType = typeof DiscountType[keyof typeof DiscountType]; - +export interface CustomerSubscriptionUpdate { + /** + * + * @type {string} + * @memberof CustomerSubscriptionUpdate + */ + product_price_id: string; +} /** - * Schema to update a discount. + * * @export - * @interface DiscountUpdate + * @interface CustomerUpdate */ -export interface DiscountUpdate { +export interface CustomerUpdate { /** * - * @type {{ [key: string]: MetadataValue1; }} - * @memberof DiscountUpdate - */ - metadata?: { [key: string]: MetadataValue1; } | null; - /** - * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DiscountUpdate + * @memberof CustomerUpdate */ - name?: string | null; + email?: string | null; /** * - * @type {DiscountUpdateCode} - * @memberof DiscountUpdate + * @type {string} + * @memberof CustomerUpdate */ - code?: DiscountUpdateCode | null; + name?: string | null; /** * - * @type {DiscountUpdateStartsAt} - * @memberof DiscountUpdate + * @type {Address} + * @memberof CustomerUpdate */ - starts_at?: DiscountUpdateStartsAt | null; + billing_address?: Address | null; /** * - * @type {DiscountUpdateEndsAt} - * @memberof DiscountUpdate + * @type {Array} + * @memberof CustomerUpdate */ - ends_at?: DiscountUpdateEndsAt | null; + tax_id?: Array | null; +} +/** + * + * @export + * @interface DiscordGuild + */ +export interface DiscordGuild { /** * - * @type {DiscountUpdateMaxRedemptions} - * @memberof DiscountUpdate + * @type {string} + * @memberof DiscordGuild */ - max_redemptions?: DiscountUpdateMaxRedemptions | null; + name: string; /** * - * @type {DiscountDuration} - * @memberof DiscountUpdate + * @type {Array} + * @memberof DiscordGuild */ - duration?: DiscountDuration | null; + roles: Array; +} +/** + * + * @export + * @interface DiscordGuildRole + */ +export interface DiscordGuildRole { /** - * Number of months the discount should be applied. * - * For this to work on yearly pricing, you should multiply this by 12. - * For example, to apply the discount for 2 years, set this to 24. - * @type {number} - * @memberof DiscountUpdate + * @type {string} + * @memberof DiscordGuildRole */ - duration_in_months?: number | null; + id: string; /** * - * @type {DiscountType} - * @memberof DiscountUpdate + * @type {string} + * @memberof DiscordGuildRole */ - type?: DiscountType | null; + name: string; /** - * Fixed amount to discount from the invoice total. + * * @type {number} - * @memberof DiscountUpdate - */ - amount?: number | null; - /** - * The currency. Currently, only `usd` is supported. - * @type {string} - * @memberof DiscountUpdate + * @memberof DiscordGuildRole */ - currency?: string | null; + position: number; /** - * Discount percentage in basis points. * - * A basis point is 1/100th of a percent. - * For example, to create a 25.5% discount, set this to 2550. - * @type {number} - * @memberof DiscountUpdate + * @type {boolean} + * @memberof DiscordGuildRole */ - basis_points?: number | null; + is_polar_bot: boolean; /** - * List of product IDs the discount can be applied to. - * @type {Array} - * @memberof DiscountUpdate + * + * @type {string} + * @memberof DiscordGuildRole */ - products?: Array | null; + color: string; } - - -/** - * @type DiscountUpdateCode - * Code customers can use to apply the discount during checkout. Must be between 3 and 256 characters long and contain only alphanumeric characters.If not provided, the discount can only be applied via the API. - * @export - */ -export type DiscountUpdateCode = string; - /** - * @type DiscountUpdateEndsAt - * Optional timestamp after which the discount is no longer redeemable. + * @type Discount + * * @export */ -export type DiscountUpdateEndsAt = string; - +export type Discount = DiscountFixedOnceForeverDuration | DiscountFixedRepeatDuration | DiscountPercentageOnceForeverDuration | DiscountPercentageRepeatDuration; /** - * @type DiscountUpdateMaxRedemptions - * Optional maximum number of times the discount can be redeemed. + * @type DiscountCreate + * * @export */ -export type DiscountUpdateMaxRedemptions = number; +export type DiscountCreate = DiscountFixedOnceForeverDurationCreate | DiscountFixedRepeatDurationCreate | DiscountPercentageOnceForeverDurationCreate | DiscountPercentageRepeatDurationCreate; /** - * @type DiscountUpdateStartsAt - * Optional timestamp after which the discount is redeemable. + * * @export */ -export type DiscountUpdateStartsAt = string; +export const DiscountDuration = { + ONCE: 'once', + FOREVER: 'forever', + REPEATING: 'repeating' +} as const; +export type DiscountDuration = typeof DiscountDuration[keyof typeof DiscountDuration]; /** - * Schema to create a file to be associated with the downloadables benefit. + * Schema for a fixed amount discount that is applied once or forever. * @export - * @interface DownloadableFileCreate + * @interface DiscountFixedOnceForeverDuration */ -export interface DownloadableFileCreate { - /** - * The organization ID. - * @type {string} - * @memberof DownloadableFileCreate - */ - organization_id?: string | null; +export interface DiscountFixedOnceForeverDuration { /** * - * @type {string} - * @memberof DownloadableFileCreate + * @type {DiscountDuration} + * @memberof DiscountFixedOnceForeverDuration */ - name: string; + duration: DiscountDuration; /** * - * @type {string} - * @memberof DownloadableFileCreate + * @type {DiscountType} + * @memberof DiscountFixedOnceForeverDuration */ - mime_type: string; + type: DiscountType; /** * * @type {number} - * @memberof DownloadableFileCreate + * @memberof DiscountFixedOnceForeverDuration */ - size: number; + amount: number; /** * * @type {string} - * @memberof DownloadableFileCreate - */ - checksum_sha256_base64?: string | null; - /** - * - * @type {S3FileCreateMultipart} - * @memberof DownloadableFileCreate + * @memberof DiscountFixedOnceForeverDuration */ - upload: S3FileCreateMultipart; + currency: string; /** - * + * Creation timestamp of the object. * @type {string} - * @memberof DownloadableFileCreate + * @memberof DiscountFixedOnceForeverDuration */ - service: DownloadableFileCreateServiceEnum; + created_at: string; /** * * @type {string} - * @memberof DownloadableFileCreate + * @memberof DiscountFixedOnceForeverDuration */ - version?: string | null; -} - - -/** - * @export - */ -export const DownloadableFileCreateServiceEnum = { - DOWNLOADABLE: 'downloadable' -} as const; -export type DownloadableFileCreateServiceEnum = typeof DownloadableFileCreateServiceEnum[keyof typeof DownloadableFileCreateServiceEnum]; - -/** - * File to be associated with the downloadables benefit. - * @export - * @interface DownloadableFileRead - */ -export interface DownloadableFileRead { + modified_at: string | null; /** * The ID of the object. * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ id: string; /** * + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountFixedOnceForeverDuration + */ + metadata: { [key: string]: MetadataValue; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - organization_id: string; + name: string; /** * * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - name: string; + code: string | null; /** * * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - path: string; + starts_at: string | null; /** * * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - mime_type: string; + ends_at: string | null; /** * * @type {number} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - size: number; + max_redemptions: number | null; /** - * + * Number of times the discount has been redeemed. + * @type {number} + * @memberof DiscountFixedOnceForeverDuration + */ + redemptions_count: number; + /** + * The organization ID. * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDuration */ - storage_version: string | null; + organization_id: string; /** * - * @type {string} - * @memberof DownloadableFileRead + * @type {Array} + * @memberof DiscountFixedOnceForeverDuration */ - checksum_etag: string | null; + products: Array; +} + + +/** + * + * @export + * @interface DiscountFixedOnceForeverDurationBase + */ +export interface DiscountFixedOnceForeverDurationBase { /** * - * @type {string} - * @memberof DownloadableFileRead + * @type {DiscountDuration} + * @memberof DiscountFixedOnceForeverDurationBase */ - checksum_sha256_base64: string | null; + duration: DiscountDuration; /** * - * @type {string} - * @memberof DownloadableFileRead + * @type {DiscountType} + * @memberof DiscountFixedOnceForeverDurationBase */ - checksum_sha256_hex: string | null; + type: DiscountType; /** * - * @type {string} - * @memberof DownloadableFileRead + * @type {number} + * @memberof DiscountFixedOnceForeverDurationBase */ - last_modified_at: string | null; + amount: number; /** * * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDurationBase */ - version: string | null; + currency: string; /** - * + * Creation timestamp of the object. * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDurationBase */ - service: DownloadableFileReadServiceEnum; + created_at: string; /** * - * @type {boolean} - * @memberof DownloadableFileRead + * @type {string} + * @memberof DiscountFixedOnceForeverDurationBase */ - is_uploaded: boolean; + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof DiscountFixedOnceForeverDurationBase + */ + id: string; /** * + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountFixedOnceForeverDurationBase + */ + metadata: { [key: string]: MetadataValue; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDurationBase */ - created_at: string; + name: string; /** * * @type {string} - * @memberof DownloadableFileRead + * @memberof DiscountFixedOnceForeverDurationBase */ - readonly size_readable: string; -} - - -/** - * @export - */ -export const DownloadableFileReadServiceEnum = { - DOWNLOADABLE: 'downloadable' -} as const; -export type DownloadableFileReadServiceEnum = typeof DownloadableFileReadServiceEnum[keyof typeof DownloadableFileReadServiceEnum]; - -/** - * - * @export - * @interface DownloadableRead - */ -export interface DownloadableRead { + code: string | null; /** * * @type {string} - * @memberof DownloadableRead + * @memberof DiscountFixedOnceForeverDurationBase */ - id: string; + starts_at: string | null; /** * * @type {string} - * @memberof DownloadableRead + * @memberof DiscountFixedOnceForeverDurationBase */ - benefit_id: string; + ends_at: string | null; /** * - * @type {FileDownload} - * @memberof DownloadableRead + * @type {number} + * @memberof DiscountFixedOnceForeverDurationBase */ - file: FileDownload; + max_redemptions: number | null; + /** + * Number of times the discount has been redeemed. + * @type {number} + * @memberof DiscountFixedOnceForeverDurationBase + */ + redemptions_count: number; + /** + * The organization ID. + * @type {string} + * @memberof DiscountFixedOnceForeverDurationBase + */ + organization_id: string; } + + /** - * + * Schema to create a fixed amount discount that is applied once or forever. * @export - * @interface Entry + * @interface DiscountFixedOnceForeverDurationCreate */ -export interface Entry { +export interface DiscountFixedOnceForeverDurationCreate { /** * - * @type {string} - * @memberof Entry + * @type {DiscountDuration} + * @memberof DiscountFixedOnceForeverDurationCreate */ - type: string; + duration: DiscountDuration; /** * - * @type {Id} - * @memberof Entry + * @type {DiscountType} + * @memberof DiscountFixedOnceForeverDurationCreate */ - id: Id; + type: DiscountType; + /** + * Fixed amount to discount from the invoice total. + * @type {number} + * @memberof DiscountFixedOnceForeverDurationCreate + */ + amount: number; + /** + * The currency. Currently, only `usd` is supported. + * @type {string} + * @memberof DiscountFixedOnceForeverDurationCreate + */ + currency?: string; /** + * Key-value object allowing you to store additional information. * - * @type {Issue} - * @memberof Entry + * The key must be a string with a maximum length of **40 characters**. + * The value must be either: + * + * * A string with a maximum length of **500 characters** + * * An integer + * * A boolean + * + * You can store up to **50 key-value pairs**. + * @type {{ [key: string]: MetadataValue1; }} + * @memberof DiscountFixedOnceForeverDurationCreate */ - attributes: Issue; + metadata?: { [key: string]: MetadataValue1; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. + * @type {string} + * @memberof DiscountFixedOnceForeverDurationCreate + */ + name: string; /** * - * @type {Array} - * @memberof Entry + * @type {string} + * @memberof DiscountFixedOnceForeverDurationCreate */ - rewards: Array | null; + code?: string | null; /** * - * @type {PledgesTypeSummaries} - * @memberof Entry + * @type {string} + * @memberof DiscountFixedOnceForeverDurationCreate */ - pledges_summary: PledgesTypeSummaries | null; + starts_at?: string | null; /** * - * @type {Array} - * @memberof Entry + * @type {string} + * @memberof DiscountFixedOnceForeverDurationCreate */ - pledges: Array | null; -} -/** - * A price that already exists for this product. - * - * Useful when updating a product if you want to keep an existing price. - * @export - * @interface ExistingProductPrice - */ -export interface ExistingProductPrice { + ends_at?: string | null; /** * + * @type {number} + * @memberof DiscountFixedOnceForeverDurationCreate + */ + max_redemptions?: number | null; + /** + * List of product IDs the discount can be applied to. + * @type {Array} + * @memberof DiscountFixedOnceForeverDurationCreate + */ + products?: Array | null; + /** + * The organization ID. * @type {string} - * @memberof ExistingProductPrice + * @memberof DiscountFixedOnceForeverDurationCreate */ - id: string; + organization_id?: string | null; } + + /** - * + * Schema for a fixed amount discount that is applied on every invoice + * for a certain number of months. * @export - * @interface ExternalOrganization + * @interface DiscountFixedRepeatDuration */ -export interface ExternalOrganization { +export interface DiscountFixedRepeatDuration { /** * - * @type {string} - * @memberof ExternalOrganization + * @type {DiscountDuration} + * @memberof DiscountFixedRepeatDuration */ - id: string; + duration: DiscountDuration; /** * - * @type {Platforms} - * @memberof ExternalOrganization + * @type {number} + * @memberof DiscountFixedRepeatDuration */ - platform: Platforms; + duration_in_months: number; /** * - * @type {string} - * @memberof ExternalOrganization + * @type {DiscountType} + * @memberof DiscountFixedRepeatDuration */ - name: string; + type: DiscountType; /** * - * @type {string} - * @memberof ExternalOrganization + * @type {number} + * @memberof DiscountFixedRepeatDuration */ - avatar_url: string; + amount: number; /** * - * @type {boolean} - * @memberof ExternalOrganization + * @type {string} + * @memberof DiscountFixedRepeatDuration */ - is_personal: boolean; + currency: string; /** - * + * Creation timestamp of the object. * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - bio: string | null; + created_at: string; /** * * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - pretty_name: string | null; + modified_at: string | null; /** - * + * The ID of the object. * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - company: string | null; + id: string; /** * + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountFixedRepeatDuration + */ + metadata: { [key: string]: MetadataValue; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - blog: string | null; + name: string; /** * * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - location: string | null; + code: string | null; /** * * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - email: string | null; + starts_at: string | null; /** * * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - twitter_username: string | null; + ends_at: string | null; + /** + * + * @type {number} + * @memberof DiscountFixedRepeatDuration + */ + max_redemptions: number | null; + /** + * Number of times the discount has been redeemed. + * @type {number} + * @memberof DiscountFixedRepeatDuration + */ + redemptions_count: number; /** * The organization ID. * @type {string} - * @memberof ExternalOrganization + * @memberof DiscountFixedRepeatDuration */ - organization_id: string | null; + organization_id: string; + /** + * + * @type {Array} + * @memberof DiscountFixedRepeatDuration + */ + products: Array; } -/** - * @type ExternalOrganizationNameFilter - * Filter by external organization name. - * @export - */ -export type ExternalOrganizationNameFilter = Array | string; - - /** * * @export + * @interface DiscountFixedRepeatDurationBase */ -export const ExternalOrganizationSortProperty = { - CREATED_AT: 'created_at', - CREATED_AT2: '-created_at', - NAME: 'name', - NAME2: '-name' -} as const; -export type ExternalOrganizationSortProperty = typeof ExternalOrganizationSortProperty[keyof typeof ExternalOrganizationSortProperty]; - -/** - * @type FileCreate - * - * @export - */ -export type FileCreate = { service: 'downloadable' } & DownloadableFileCreate | { service: 'organization_avatar' } & OrganizationAvatarFileCreate | { service: 'product_media' } & ProductMediaFileCreate; -/** - * - * @export - * @interface FileDownload - */ -export interface FileDownload { - /** - * The ID of the object. - * @type {string} - * @memberof FileDownload - */ - id: string; - /** - * - * @type {string} - * @memberof FileDownload - */ - organization_id: string; +export interface DiscountFixedRepeatDurationBase { /** * - * @type {string} - * @memberof FileDownload + * @type {DiscountDuration} + * @memberof DiscountFixedRepeatDurationBase */ - name: string; + duration: DiscountDuration; /** * - * @type {string} - * @memberof FileDownload + * @type {number} + * @memberof DiscountFixedRepeatDurationBase */ - path: string; + duration_in_months: number; /** * - * @type {string} - * @memberof FileDownload + * @type {DiscountType} + * @memberof DiscountFixedRepeatDurationBase */ - mime_type: string; + type: DiscountType; /** * * @type {number} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - size: number; + amount: number; /** * * @type {string} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - storage_version: string | null; + currency: string; /** - * + * Creation timestamp of the object. * @type {string} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - checksum_etag: string | null; + created_at: string; /** * * @type {string} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - checksum_sha256_base64: string | null; + modified_at: string | null; /** - * + * The ID of the object. * @type {string} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - checksum_sha256_hex: string | null; + id: string; /** * - * @type {string} - * @memberof FileDownload + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountFixedRepeatDurationBase */ - last_modified_at: string | null; + metadata: { [key: string]: MetadataValue; }; /** - * - * @type {S3DownloadURL} - * @memberof FileDownload + * Name of the discount. Will be displayed to the customer when the discount is applied. + * @type {string} + * @memberof DiscountFixedRepeatDurationBase */ - download: S3DownloadURL; + name: string; /** * * @type {string} - * @memberof FileDownload + * @memberof DiscountFixedRepeatDurationBase */ - version: string | null; + code: string | null; /** * - * @type {boolean} - * @memberof FileDownload + * @type {string} + * @memberof DiscountFixedRepeatDurationBase */ - is_uploaded: boolean; + starts_at: string | null; /** * - * @type {FileServiceTypes} - * @memberof FileDownload + * @type {string} + * @memberof DiscountFixedRepeatDurationBase */ - service: FileServiceTypes; + ends_at: string | null; /** * - * @type {string} - * @memberof FileDownload + * @type {number} + * @memberof DiscountFixedRepeatDurationBase */ - readonly size_readable: string; -} - - -/** - * - * @export - * @interface FilePatch - */ -export interface FilePatch { + max_redemptions: number | null; /** - * - * @type {string} - * @memberof FilePatch + * Number of times the discount has been redeemed. + * @type {number} + * @memberof DiscountFixedRepeatDurationBase */ - name?: string | null; + redemptions_count: number; /** - * + * The organization ID. * @type {string} - * @memberof FilePatch + * @memberof DiscountFixedRepeatDurationBase */ - version?: string | null; + organization_id: string; } -/** - * @type FileRead - * - * @export - */ -export type FileRead = { service: 'downloadable' } & DownloadableFileRead | { service: 'organization_avatar' } & OrganizationAvatarFileRead | { service: 'product_media' } & ProductMediaFileRead; -/** - * - * @export - */ -export const FileServiceTypes = { - DOWNLOADABLE: 'downloadable', - PRODUCT_MEDIA: 'product_media', - ORGANIZATION_AVATAR: 'organization_avatar' -} as const; -export type FileServiceTypes = typeof FileServiceTypes[keyof typeof FileServiceTypes]; /** - * + * Schema to create a fixed amount discount that is applied on every invoice + * for a certain number of months. * @export - * @interface FileUpload + * @interface DiscountFixedRepeatDurationCreate */ -export interface FileUpload { - /** - * The ID of the object. - * @type {string} - * @memberof FileUpload - */ - id: string; - /** - * - * @type {string} - * @memberof FileUpload - */ - organization_id: string; +export interface DiscountFixedRepeatDurationCreate { /** * - * @type {string} - * @memberof FileUpload + * @type {DiscountDuration} + * @memberof DiscountFixedRepeatDurationCreate */ - name: string; + duration: DiscountDuration; /** + * Number of months the discount should be applied. * - * @type {string} - * @memberof FileUpload + * For this to work on yearly pricing, you should multiply this by 12. + * For example, to apply the discount for 2 years, set this to 24. + * @type {number} + * @memberof DiscountFixedRepeatDurationCreate */ - path: string; + duration_in_months: number; /** * - * @type {string} - * @memberof FileUpload + * @type {DiscountType} + * @memberof DiscountFixedRepeatDurationCreate */ - mime_type: string; + type: DiscountType; /** - * + * Fixed amount to discount from the invoice total. * @type {number} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - size: number; + amount: number; /** - * + * The currency. Currently, only `usd` is supported. * @type {string} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - storage_version: string | null; + currency?: string; /** + * Key-value object allowing you to store additional information. * - * @type {string} - * @memberof FileUpload - */ - checksum_etag: string | null; - /** + * The key must be a string with a maximum length of **40 characters**. + * The value must be either: * - * @type {string} - * @memberof FileUpload + * * A string with a maximum length of **500 characters** + * * An integer + * * A boolean + * + * You can store up to **50 key-value pairs**. + * @type {{ [key: string]: MetadataValue1; }} + * @memberof DiscountFixedRepeatDurationCreate */ - checksum_sha256_base64: string | null; + metadata?: { [key: string]: MetadataValue1; }; /** - * + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - checksum_sha256_hex: string | null; + name: string; /** * * @type {string} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - last_modified_at: string | null; + code?: string | null; /** * - * @type {S3FileUploadMultipart} - * @memberof FileUpload + * @type {string} + * @memberof DiscountFixedRepeatDurationCreate */ - upload: S3FileUploadMultipart; + starts_at?: string | null; /** * * @type {string} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - version: string | null; + ends_at?: string | null; /** * - * @type {boolean} - * @memberof FileUpload + * @type {number} + * @memberof DiscountFixedRepeatDurationCreate */ - is_uploaded?: boolean; + max_redemptions?: number | null; /** - * - * @type {FileServiceTypes} - * @memberof FileUpload + * List of product IDs the discount can be applied to. + * @type {Array} + * @memberof DiscountFixedRepeatDurationCreate */ - service: FileServiceTypes; + products?: Array | null; /** - * + * The organization ID. * @type {string} - * @memberof FileUpload + * @memberof DiscountFixedRepeatDurationCreate */ - readonly size_readable: string; + organization_id?: string | null; } /** - * - * @export - * @interface FileUploadCompleted + * @type DiscountIDFilter + * Filter by discount ID. + * @export */ -export interface FileUploadCompleted { - /** - * - * @type {string} - * @memberof FileUploadCompleted - */ - id: string; - /** - * - * @type {string} - * @memberof FileUploadCompleted - */ - path: string; - /** - * - * @type {Array} - * @memberof FileUploadCompleted - */ - parts: Array; -} +export type DiscountIDFilter = Array | string; + /** - * + * @type DiscountIDFilter1 + * Filter by discount ID. * @export - * @interface Funding */ -export interface Funding { - /** - * - * @type {CurrencyAmount} - * @memberof Funding - */ - funding_goal?: CurrencyAmount | null; - /** - * - * @type {CurrencyAmount} - * @memberof Funding - */ - pledges_sum?: CurrencyAmount | null; -} +export type DiscountIDFilter1 = Array | string; + /** - * + * Schema for a percentage discount that is applied once or forever. * @export - * @interface GitHubInvitesBenefitOrganization + * @interface DiscountPercentageOnceForeverDuration */ -export interface GitHubInvitesBenefitOrganization { +export interface DiscountPercentageOnceForeverDuration { /** * - * @type {string} - * @memberof GitHubInvitesBenefitOrganization + * @type {DiscountDuration} + * @memberof DiscountPercentageOnceForeverDuration */ - name: string; + duration: DiscountDuration; /** * - * @type {boolean} - * @memberof GitHubInvitesBenefitOrganization + * @type {DiscountType} + * @memberof DiscountPercentageOnceForeverDuration */ - is_personal: boolean; + type: DiscountType; /** * + * @type {number} + * @memberof DiscountPercentageOnceForeverDuration + */ + basis_points: number; + /** + * Creation timestamp of the object. * @type {string} - * @memberof GitHubInvitesBenefitOrganization + * @memberof DiscountPercentageOnceForeverDuration */ - plan_name: string; + created_at: string; /** * - * @type {boolean} - * @memberof GitHubInvitesBenefitOrganization + * @type {string} + * @memberof DiscountPercentageOnceForeverDuration */ - is_free: boolean; -} -/** - * - * @export - * @interface GitHubInvitesBenefitRepositories - */ -export interface GitHubInvitesBenefitRepositories { + modified_at: string | null; /** - * - * @type {Array} - * @memberof GitHubInvitesBenefitRepositories + * The ID of the object. + * @type {string} + * @memberof DiscountPercentageOnceForeverDuration */ - repositories: Array; + id: string; /** * - * @type {Array} - * @memberof GitHubInvitesBenefitRepositories + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountPercentageOnceForeverDuration */ - organizations: Array; -} -/** - * - * @export - * @interface GitHubInvitesBenefitRepository - */ -export interface GitHubInvitesBenefitRepository { + metadata: { [key: string]: MetadataValue; }; /** - * + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof GitHubInvitesBenefitRepository + * @memberof DiscountPercentageOnceForeverDuration */ - repository_owner: string; + name: string; /** * * @type {string} - * @memberof GitHubInvitesBenefitRepository + * @memberof DiscountPercentageOnceForeverDuration */ - repository_name: string; -} -/** - * - * @export - * @interface GithubUser - */ -export interface GithubUser { + code: string | null; /** * * @type {string} - * @memberof GithubUser + * @memberof DiscountPercentageOnceForeverDuration */ - username: string; + starts_at: string | null; /** * * @type {string} - * @memberof GithubUser + * @memberof DiscountPercentageOnceForeverDuration */ - avatar_url: string; -} -/** - * - * @export - * @interface HTTPValidationError - */ -export interface HTTPValidationError { + ends_at: string | null; /** * - * @type {Array} - * @memberof HTTPValidationError + * @type {number} + * @memberof DiscountPercentageOnceForeverDuration */ - detail?: Array; -} -/** - * @type Id - * @export - */ -export type Id = string; - -/** - * - * @export - * @interface InstallationCreate - */ -export interface InstallationCreate { + max_redemptions: number | null; /** - * + * Number of times the discount has been redeemed. * @type {number} - * @memberof InstallationCreate + * @memberof DiscountPercentageOnceForeverDuration */ - installation_id: number; + redemptions_count: number; /** * The organization ID. * @type {string} - * @memberof InstallationCreate + * @memberof DiscountPercentageOnceForeverDuration */ organization_id: string; + /** + * + * @type {Array} + * @memberof DiscountPercentageOnceForeverDuration + */ + products: Array; } -/** - * - * @export - */ -export const Interval = { - YEAR: 'year', - MONTH: 'month', - WEEK: 'week', - DAY: 'day', - HOUR: 'hour' -} as const; -export type Interval = typeof Interval[keyof typeof Interval]; /** * * @export - * @interface IntrospectTokenResponse + * @interface DiscountPercentageOnceForeverDurationBase */ -export interface IntrospectTokenResponse { +export interface DiscountPercentageOnceForeverDurationBase { /** * - * @type {boolean} - * @memberof IntrospectTokenResponse + * @type {DiscountDuration} + * @memberof DiscountPercentageOnceForeverDurationBase */ - active: boolean; + duration: DiscountDuration; /** * - * @type {string} - * @memberof IntrospectTokenResponse + * @type {DiscountType} + * @memberof DiscountPercentageOnceForeverDurationBase */ - client_id: string; + type: DiscountType; /** * + * @type {number} + * @memberof DiscountPercentageOnceForeverDurationBase + */ + basis_points: number; + /** + * Creation timestamp of the object. * @type {string} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - token_type: IntrospectTokenResponseTokenTypeEnum; + created_at: string; /** * * @type {string} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - scope: string; + modified_at: string | null; + /** + * The ID of the object. + * @type {string} + * @memberof DiscountPercentageOnceForeverDurationBase + */ + id: string; /** * - * @type {SubType} - * @memberof IntrospectTokenResponse + * @type {{ [key: string]: MetadataValue; }} + * @memberof DiscountPercentageOnceForeverDurationBase */ - sub_type: SubType; + metadata: { [key: string]: MetadataValue; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. + * @type {string} + * @memberof DiscountPercentageOnceForeverDurationBase + */ + name: string; /** * * @type {string} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - sub: string; + code: string | null; /** * * @type {string} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - aud: string; + starts_at: string | null; /** * * @type {string} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - iss: string; + ends_at: string | null; /** * * @type {number} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - exp: number; + max_redemptions: number | null; /** - * + * Number of times the discount has been redeemed. * @type {number} - * @memberof IntrospectTokenResponse + * @memberof DiscountPercentageOnceForeverDurationBase */ - iat: number; -} - + redemptions_count: number; + /** + * The organization ID. + * @type {string} + * @memberof DiscountPercentageOnceForeverDurationBase + */ + organization_id: string; +} -/** - * @export - */ -export const IntrospectTokenResponseTokenTypeEnum = { - ACCESS_TOKEN: 'access_token', - REFRESH_TOKEN: 'refresh_token' -} as const; -export type IntrospectTokenResponseTokenTypeEnum = typeof IntrospectTokenResponseTokenTypeEnum[keyof typeof IntrospectTokenResponseTokenTypeEnum]; /** - * + * Schema to create a percentage discount that is applied once or forever. * @export - * @interface Issue + * @interface DiscountPercentageOnceForeverDurationCreate */ -export interface Issue { +export interface DiscountPercentageOnceForeverDurationCreate { /** * - * @type {string} - * @memberof Issue + * @type {DiscountDuration} + * @memberof DiscountPercentageOnceForeverDurationCreate */ - id: string; + duration: DiscountDuration; /** * - * @type {Platforms} - * @memberof Issue + * @type {DiscountType} + * @memberof DiscountPercentageOnceForeverDurationCreate */ - platform: Platforms; + type: DiscountType; /** - * GitHub #number + * Discount percentage in basis points. + * + * A basis point is 1/100th of a percent. + * For example, to create a 25.5% discount, set this to 2550. * @type {number} - * @memberof Issue + * @memberof DiscountPercentageOnceForeverDurationCreate */ - number: number; + basis_points: number; /** - * GitHub issue title + * Key-value object allowing you to store additional information. + * + * The key must be a string with a maximum length of **40 characters**. + * The value must be either: + * + * * A string with a maximum length of **500 characters** + * * An integer + * * A boolean + * + * You can store up to **50 key-value pairs**. + * @type {{ [key: string]: MetadataValue1; }} + * @memberof DiscountPercentageOnceForeverDurationCreate + */ + metadata?: { [key: string]: MetadataValue1; }; + /** + * Name of the discount. Will be displayed to the customer when the discount is applied. * @type {string} - * @memberof Issue + * @memberof DiscountPercentageOnceForeverDurationCreate */ - title: string; + name: string; /** * * @type {string} - * @memberof Issue + * @memberof DiscountPercentageOnceForeverDurationCreate */ - body?: string | null; + code?: string | null; /** * - * @type {number} - * @memberof Issue + * @type {string} + * @memberof DiscountPercentageOnceForeverDurationCreate */ - comments?: number | null; + starts_at?: string | null; /** * - * @type {Array