diff --git a/app/account_linking.py b/app/account_linking.py index 2f42aa0bc..43b24c9a4 100644 --- a/app/account_linking.py +++ b/app/account_linking.py @@ -35,6 +35,7 @@ class SLPlanType(Enum): Free = 1 Premium = 2 + PremiumLifetime = 3 @dataclass @@ -75,6 +76,7 @@ def send_user_plan_changed_event(partner_user: PartnerUser) -> Optional[int]: def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan): sub = PartnerSubscription.get_by(partner_user_id=partner_user.id) + is_lifetime = plan.type == SLPlanType.PremiumLifetime if plan.type == SLPlanType.Free: if sub is not None: LOG.i( @@ -83,25 +85,30 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan): PartnerSubscription.delete(sub.id) agent.record_custom_event("PlanChange", {"plan": "free"}) else: + end_time = plan.expiration + if plan.type == SLPlanType.PremiumLifetime: + end_time = None if sub is None: LOG.i( f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]" ) create_partner_subscription( partner_user=partner_user, - expiration=plan.expiration, + expiration=end_time, + lifetime=is_lifetime, msg="Upgraded via partner. User did not have a previous partner subscription", ) agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"}) else: - if sub.end_at != plan.expiration: + if sub.end_at != plan.expiration or sub.lifetime != is_lifetime: LOG.i( f"Updating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]" ) agent.record_custom_event( "PlanChange", {"plan": "premium", "type": "extension"} ) - sub.end_at = plan.expiration + sub.end_at = plan.expiration if not is_lifetime else None + sub.lifetime = is_lifetime emit_user_audit_log( user=partner_user.user, action=UserAuditLogAction.SubscriptionExtended, diff --git a/app/models.py b/app/models.py index 12013381a..8186a7e1e 100644 --- a/app/models.py +++ b/app/models.py @@ -3778,7 +3778,8 @@ class PartnerSubscription(Base, ModelMixin): ) # when the partner subscription ends - end_at = sa.Column(ArrowType, nullable=False, index=True) + end_at = sa.Column(ArrowType, nullable=True, index=True) + lifetime = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") partner_user = orm.relationship(PartnerUser) @@ -3800,7 +3801,9 @@ def find_by_user_id(cls, user_id: int) -> Optional[PartnerSubscription]: return None def is_active(self): - return self.end_at > arrow.now().shift(days=-_PARTNER_SUBSCRIPTION_GRACE_DAYS) + return self.lifetime or self.end_at > arrow.now().shift( + days=-_PARTNER_SUBSCRIPTION_GRACE_DAYS + ) # endregion diff --git a/app/partner_user_utils.py b/app/partner_user_utils.py index ba665f07d..6254ba6f7 100644 --- a/app/partner_user_utils.py +++ b/app/partner_user_utils.py @@ -33,12 +33,14 @@ def create_partner_user( def create_partner_subscription( partner_user: PartnerUser, - expiration: Optional[Arrow], + expiration: Optional[Arrow] = None, + lifetime: bool = False, msg: Optional[str] = None, ) -> PartnerSubscription: instance = PartnerSubscription.create( partner_user_id=partner_user.id, end_at=expiration, + lifetime=lifetime, ) message = "User upgraded through partner subscription" diff --git a/app/proton/proton_client.py b/app/proton/proton_client.py index f06325b8b..8c086ec8e 100644 --- a/app/proton/proton_client.py +++ b/app/proton/proton_client.py @@ -16,6 +16,7 @@ PLAN_FREE = 1 PLAN_PREMIUM = 2 +PLAN_PREMIUM_LIFETIME = 3 @dataclass @@ -112,10 +113,13 @@ def get_user(self) -> Optional[UserInformation]: if plan_value == PLAN_FREE: plan = SLPlan(type=SLPlanType.Free, expiration=None) elif plan_value == PLAN_PREMIUM: + expiration = info.get("Expiration", "1") plan = SLPlan( type=SLPlanType.Premium, - expiration=Arrow.fromtimestamp(info["PlanExpiration"], tzinfo="utc"), + expiration=Arrow.fromtimestamp(expiration, tzinfo="utc"), ) + elif plan_value == PLAN_PREMIUM_LIFETIME: + plan = SLPlan(SLPlanType.PremiumLifetime, expiration=None) else: raise Exception(f"Invalid value for plan: {plan_value}") diff --git a/migrations/versions/2024_112619_085f77996ce3_.py b/migrations/versions/2024_112619_085f77996ce3_.py new file mode 100644 index 000000000..2aba0eecc --- /dev/null +++ b/migrations/versions/2024_112619_085f77996ce3_.py @@ -0,0 +1,35 @@ +"""empty message + +Revision ID: 085f77996ce3 +Revises: 0f3ee15b0014 +Create Date: 2024-11-26 19:20:32.227899 + +""" +import sqlalchemy_utils +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '085f77996ce3' +down_revision = '0f3ee15b0014' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('partner_subscription', sa.Column('lifetime', sa.Boolean(), server_default='0', nullable=False)) + op.alter_column('partner_subscription', 'end_at', + existing_type=postgresql.TIMESTAMP(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('partner_subscription', 'end_at', + existing_type=postgresql.TIMESTAMP(), + nullable=False) + op.drop_column('partner_subscription', 'lifetime') + # ### end Alembic commands ### diff --git a/templates/dashboard/setting.html b/templates/dashboard/setting.html index 7f5ef64cd..767ec38c6 100644 --- a/templates/dashboard/setting.html +++ b/templates/dashboard/setting.html @@ -79,7 +79,13 @@ {% endif %} - {% if partner_sub %}
Premium subscription managed by {{ partner_name }}.
{% endif %} + {% if partner_sub %} + {% if partner_sub.lifetime %} +
Premium lifetime subscription managed by {{ partner_name }}.
+ {% else %} +
Premium subscription managed by {{ partner_name }}.
+ {% endif %} + {% endif %} {% elif current_user.in_trial() %} Your Premium trial expires {{ current_user.trial_end | dt }}. {% else %} diff --git a/tests/proton/test_account_linking.py b/tests/proton/test_account_linking.py new file mode 100644 index 000000000..d2511c0c0 --- /dev/null +++ b/tests/proton/test_account_linking.py @@ -0,0 +1,100 @@ +import arrow + +from app.account_linking import ( + SLPlan, + SLPlanType, + set_plan_for_partner_user, +) +from app.db import Session +from app.models import User, PartnerUser, PartnerSubscription +from app.proton.utils import get_proton_partner +from app.utils import random_string +from tests.utils import random_email + +partner_user_id: int = 0 + + +def setup_module(): + global partner_user_id + email = random_email() + external_id = random_string() + sl_user = User.create(email, commit=True) + partner_user_id = PartnerUser.create( + user_id=sl_user.id, + partner_id=get_proton_partner().id, + external_user_id=external_id, + partner_email=email, + commit=True, + ).id + + +def setup_function(func): + Session.query(PartnerSubscription).delete() + + +def test_free_plan_removes_sub(): + pu = PartnerUser.get(partner_user_id) + sub_id = PartnerSubscription.create( + partner_user_id=partner_user_id, + end_at=arrow.utcnow(), + lifetime=False, + commit=True, + ).id + set_plan_for_partner_user(pu, plan=SLPlan(type=SLPlanType.Free, expiration=None)) + assert PartnerSubscription.get(sub_id) is None + + +def test_premium_plan_updates_expiration(): + pu = PartnerUser.get(partner_user_id) + sub_id = PartnerSubscription.create( + partner_user_id=partner_user_id, + end_at=arrow.utcnow(), + lifetime=False, + commit=True, + ).id + new_expiration = arrow.utcnow().shift(days=+10) + set_plan_for_partner_user( + pu, plan=SLPlan(type=SLPlanType.Premium, expiration=new_expiration) + ) + assert PartnerSubscription.get(sub_id).end_at == new_expiration + + +def test_premium_plan_creates_sub(): + pu = PartnerUser.get(partner_user_id) + new_expiration = arrow.utcnow().shift(days=+10) + set_plan_for_partner_user( + pu, plan=SLPlan(type=SLPlanType.Premium, expiration=new_expiration) + ) + assert ( + PartnerSubscription.get_by(partner_user_id=partner_user_id).end_at + == new_expiration + ) + + +def test_lifetime_creates_sub(): + pu = PartnerUser.get(partner_user_id) + new_expiration = arrow.utcnow().shift(days=+10) + set_plan_for_partner_user( + pu, plan=SLPlan(type=SLPlanType.PremiumLifetime, expiration=new_expiration) + ) + sub = PartnerSubscription.get_by(partner_user_id=partner_user_id) + assert sub is not None + assert sub.end_at is None + assert sub.lifetime + + +def test_lifetime_updates_sub(): + pu = PartnerUser.get(partner_user_id) + sub_id = PartnerSubscription.create( + partner_user_id=partner_user_id, + end_at=arrow.utcnow(), + lifetime=False, + commit=True, + ).id + set_plan_for_partner_user( + pu, plan=SLPlan(type=SLPlanType.PremiumLifetime, expiration=arrow.utcnow()) + ) + sub = PartnerSubscription.get(sub_id) + assert sub is not None + assert sub.end_at is None + assert sub.lifetime