Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve MX and SPF domain handling #2246

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]:
return res


PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
"PARTNER_DNS_CUSTOM_DOMAINS"
)
PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
)

MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
Expand Down
36 changes: 31 additions & 5 deletions app/custom_domain_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
MxRecord,
DNSClient,
is_mx_equivalent,
get_network_dns_client,
Expand All @@ -28,10 +29,10 @@ def __init__(
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS
self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
)

def get_ownership_verification_record(self, domain: CustomDomain) -> str:
Expand All @@ -43,6 +44,29 @@ def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"

def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))

return records

def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN

def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"

def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
Expand Down Expand Up @@ -116,11 +140,12 @@ def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)

if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
errors=[f"{record.priority} {record.domain}" for record in mx_domains],
)
else:
custom_domain.verified = True
Expand All @@ -131,7 +156,8 @@ def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
Expand Down
4 changes: 2 additions & 2 deletions app/dashboard/views/domain_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30)
Session.commit()

spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"

domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()

Expand Down Expand Up @@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
Expand Down
37 changes: 22 additions & 15 deletions app/dns_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
from dataclasses import dataclass
from typing import List, Optional

import dns.resolver

Expand All @@ -8,8 +9,14 @@
_include_spf = "include:"


@dataclass
class MxRecord:
priority: int
domain: str


def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
Expand All @@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)

if len(mx_domains) < len(ref_mx_domains):
return False

for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
for actual, expected in zip(mx_domains, ref_mx_domains):
if actual.domain != expected.domain:
return False

return True
Expand All @@ -37,7 +44,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
pass

@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass

def get_spf_domain(self, hostname: str) -> List[str]:
Expand Down Expand Up @@ -81,7 +88,7 @@ def get_cname_record(self, hostname: str) -> Optional[str]:
except Exception:
return None

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
Expand All @@ -92,8 +99,8 @@ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x.priority)
except Exception:
return []

Expand All @@ -112,14 +119,14 @@ def get_txt_record(self, hostname: str) -> List[str]:
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}

def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname

def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list

def set_txt_record(self, hostname: str, txt_list: List[str]):
Expand All @@ -128,9 +135,9 @@ def set_txt_record(self, hostname: str, txt_list: List[str]):
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)

def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
return sorted(mx_list, key=lambda x: x.priority)

def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
Expand All @@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)


def get_mx_domains(hostname: str) -> [(int, str)]:
def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname)
2 changes: 1 addition & 1 deletion app/email_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
"""
priority_domains = get_mx_domains(domain)

return [d[:-1] for _, d in priority_domains]
return [d.domain[:-1] for d in priority_domains]


def personal_email_already_used(email_address: str) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,9 +2766,9 @@ def is_proton(self) -> bool:

from app.email_utils import get_email_local_part

mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email))
mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain
if mx_domains and mx_domains[0][1] in (
if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.",
"mailsec.protonmail.ch.",
):
Expand Down
7 changes: 5 additions & 2 deletions cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import (
Expand Down Expand Up @@ -905,9 +906,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted")


def check_single_custom_domain(custom_domain):
def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user
LOG.w(
"The MX record is not correctly set for %s %s %s",
Expand Down
8 changes: 5 additions & 3 deletions templates/dashboard/domain_detail/dns.html
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,24 @@ <h1 class="h2">{{ custom_domain.domain }}</h1>
<br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}

{% for record in expected_mx_records %}

<div class="mb-3 p-3 dns-record">
Record: MX
<br />
Domain: {{ custom_domain.domain }} or
<b>@</b>
<br />
Priority: {{ priority }}
Priority: {{ record.priority }}
<br />
Target: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em>
data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div>
{% endfor %}

<form method="post" action="#mx-form">
{{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx">
Expand Down
Loading
Loading