diff --git a/app/config.py b/app/config.py index 1962976e7..0ad1d37e4 100644 --- a/app/config.py +++ b/app/config.py @@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]: return res -PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") -PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( - "PARTNER_DOMAIN_VALIDATION_PREFIXES" +PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict( + "PARTNER_DNS_CUSTOM_DOMAINS" +) +PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( + "PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES" ) MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get( diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 4dbf201a9..9ebbd636b 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -5,6 +5,7 @@ from app.constants import DMARC_RECORD from app.db import Session from app.dns_utils import ( + MxRecord, DNSClient, is_mx_equivalent, get_network_dns_client, @@ -28,10 +29,10 @@ def __init__( ): self.dkim_domain = dkim_domain self._dns_client = dns_client - self._partner_domains = partner_domains or config.PARTNER_DOMAINS + self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS self._partner_domain_validation_prefixes = ( partner_domains_validation_prefixes - or config.PARTNER_DOMAIN_VALIDATION_PREFIXES + or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES ) def get_ownership_verification_record(self, domain: CustomDomain) -> str: @@ -43,6 +44,29 @@ def get_ownership_verification_record(self, domain: CustomDomain) -> str: prefix = self._partner_domain_validation_prefixes[domain.partner_id] return f"{prefix}-verification={domain.ownership_txt_token}" + def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]: + records = [] + if domain.partner_id is not None and domain.partner_id in self._partner_domains: + domain = self._partner_domains[domain.partner_id] + records.append(MxRecord(10, f"mx1.{domain}.")) + records.append(MxRecord(20, f"mx2.{domain}.")) + else: + # Default ones + for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY: + records.append(MxRecord(priority, domain)) + + return records + + def get_expected_spf_domain(self, domain: CustomDomain) -> str: + if domain.partner_id is not None and domain.partner_id in self._partner_domains: + return self._partner_domains[domain.partner_id] + else: + return config.EMAIL_DOMAIN + + def get_expected_spf_record(self, domain: CustomDomain) -> str: + spf_domain = self.get_expected_spf_domain(domain) + return f"v=spf1 include:{spf_domain} ~all" + def get_dkim_records(self, domain: CustomDomain) -> {str: str}: """ Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, @@ -116,11 +140,12 @@ def validate_mx_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) + expected_mx_records = self.get_expected_mx_records(custom_domain) - if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): + if not is_mx_equivalent(mx_domains, expected_mx_records): return DomainValidationResult( success=False, - errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], + errors=[f"{record.priority} {record.domain}" for record in mx_domains], ) else: custom_domain.verified = True @@ -131,7 +156,8 @@ def validate_spf_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) - if config.EMAIL_DOMAIN in spf_domains: + expected_spf_domain = self.get_expected_spf_domain(custom_domain) + if expected_spf_domain in spf_domains: custom_domain.spf_verified = True Session.commit() return DomainValidationResult(success=True, errors=[]) diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 0911a748d..2b1ac32fb 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id): custom_domain.ownership_txt_token = random_string(30) Session.commit() - spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" - domain_validator = CustomDomainValidation(EMAIL_DOMAIN) csrf_form = CSRFValidationForm() @@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id): ownership_record=domain_validator.get_ownership_verification_record( custom_domain ), + expected_mx_records=domain_validator.get_expected_mx_records(custom_domain), dkim_records=domain_validator.get_dkim_records(custom_domain), + spf_record=domain_validator.get_expected_spf_record(custom_domain), dmarc_record=DMARC_RECORD, **locals(), ) diff --git a/app/dns_utils.py b/app/dns_utils.py index 2ce699340..995a1e17d 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import List, Tuple, Optional +from dataclasses import dataclass +from typing import List, Optional import dns.resolver @@ -8,8 +9,14 @@ _include_spf = "include:" +@dataclass +class MxRecord: + priority: int + domain: str + + def is_mx_equivalent( - mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] + mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord] ) -> bool: """ Compare mx_domains with ref_mx_domains to see if they are equivalent. @@ -18,14 +25,14 @@ def is_mx_equivalent( The priority order is taken into account but not the priority number. For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] """ - mx_domains = sorted(mx_domains, key=lambda x: x[0]) - ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) + mx_domains = sorted(mx_domains, key=lambda x: x.priority) + ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority) if len(mx_domains) < len(ref_mx_domains): return False - for i in range(len(ref_mx_domains)): - if mx_domains[i][1] != ref_mx_domains[i][1]: + for actual, expected in zip(mx_domains, ref_mx_domains): + if actual.domain != expected.domain: return False return True @@ -37,7 +44,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]: pass @abstractmethod - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: pass def get_spf_domain(self, hostname: str) -> List[str]: @@ -81,7 +88,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]: except Exception: return None - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: """ return list of (priority, domain name) sorted by priority (lowest priority first) domain name ends with a "." at the end. @@ -92,8 +99,8 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: for a in answers: record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' parts = record.split(" ") - ret.append((int(parts[0]), parts[1])) - return sorted(ret, key=lambda x: x[0]) + ret.append(MxRecord(priority=int(parts[0]), domain=parts[1])) + return sorted(ret, key=lambda x: x.priority) except Exception: return [] @@ -112,14 +119,14 @@ def get_txt_record(self, hostname: str) -> List[str]: class InMemoryDNSClient(DNSClient): def __init__(self): self.cname_records: dict[str, Optional[str]] = {} - self.mx_records: dict[str, List[Tuple[int, str]]] = {} + self.mx_records: dict[str, List[MxRecord]] = {} self.spf_records: dict[str, List[str]] = {} self.txt_records: dict[str, List[str]] = {} def set_cname_record(self, hostname: str, cname: str): self.cname_records[hostname] = cname - def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): + def set_mx_records(self, hostname: str, mx_list: List[MxRecord]): self.mx_records[hostname] = mx_list def set_txt_record(self, hostname: str, txt_list: List[str]): @@ -128,9 +135,9 @@ def set_txt_record(self, hostname: str, txt_list: List[str]): def get_cname_record(self, hostname: str) -> Optional[str]: return self.cname_records.get(hostname) - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: mx_list = self.mx_records.get(hostname, []) - return sorted(mx_list, key=lambda x: x[0]) + return sorted(mx_list, key=lambda x: x.priority) def get_txt_record(self, hostname: str) -> List[str]: return self.txt_records.get(hostname, []) @@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient: return NetworkDNSClient(NAMESERVERS) -def get_mx_domains(hostname: str) -> [(int, str)]: +def get_mx_domains(hostname: str) -> List[MxRecord]: return get_network_dns_client().get_mx_domains(hostname) diff --git a/app/email_utils.py b/app/email_utils.py index 5ff34d06d..ca5aa0414 100644 --- a/app/email_utils.py +++ b/app/email_utils.py @@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]: """ priority_domains = get_mx_domains(domain) - return [d[:-1] for _, d in priority_domains] + return [d.domain[:-1] for d in priority_domains] def personal_email_already_used(email_address: str) -> bool: diff --git a/app/models.py b/app/models.py index c86c5df65..092c10608 100644 --- a/app/models.py +++ b/app/models.py @@ -2766,9 +2766,9 @@ def is_proton(self) -> bool: from app.email_utils import get_email_local_part - mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email)) + mx_domains = get_mx_domains(get_email_local_part(self.email)) # Proton is the first domain - if mx_domains and mx_domains[0][1] in ( + if mx_domains and mx_domains[0].domain in ( "mail.protonmail.ch.", "mailsec.protonmail.ch.", ): diff --git a/cron.py b/cron.py index bfc150d94..f8bb09f4a 100644 --- a/cron.py +++ b/cron.py @@ -14,6 +14,7 @@ from app import s3, config from app.alias_utils import nb_email_log_for_mailbox from app.api.views.apple import verify_receipt +from app.custom_domain_validation import CustomDomainValidation from app.db import Session from app.dns_utils import get_mx_domains, is_mx_equivalent from app.email_utils import ( @@ -905,9 +906,11 @@ def check_custom_domain(): LOG.i("custom domain has been deleted") -def check_single_custom_domain(custom_domain): +def check_single_custom_domain(custom_domain: CustomDomain): mx_domains = get_mx_domains(custom_domain.domain) - if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): + validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN) + expected_custom_domains = validator.get_expected_mx_records(custom_domain) + if not is_mx_equivalent(mx_domains, expected_custom_domains): user = custom_domain.user LOG.w( "The MX record is not correctly set for %s %s %s", diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 4058f5ea6..e100feead 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -91,7 +91,8 @@