diff --git a/src/handlers/dns/DnsStreamHandler.cpp b/src/handlers/dns/DnsStreamHandler.cpp index 08022ef3d..fe2bd694c 100644 --- a/src/handlers/dns/DnsStreamHandler.cpp +++ b/src/handlers/dns/DnsStreamHandler.cpp @@ -163,7 +163,8 @@ void DnsStreamHandler::process_udp_packet_cb(pcpp::Packet &payload, PacketDirect if (metric_port) { DnsLayer dnsLayer(udpLayer, &payload); if (!_filtering(dnsLayer, dir, l3, pcpp::UDP, metric_port, stamp)) { - _metrics->process_dns_layer(dnsLayer, dir, l3, pcpp::UDP, flowkey, metric_port, stamp); + _metrics->process_dns_layer(dnsLayer, dir, l3, pcpp::UDP, flowkey, metric_port, _static_suffix_size, stamp); + _static_suffix_size = 0; // signal for chained stream handlers, if we have any udp_signal(payload, dir, l3, flowkey, stamp); } @@ -243,7 +244,8 @@ void DnsStreamHandler::tcp_message_ready_cb(int8_t side, const pcpp::TcpStreamDa pcpp::Packet dummy_packet; DnsLayer dnsLayer(data.get(), size, nullptr, &dummy_packet); if (!_filtering(dnsLayer, dir, l3Type, pcpp::UDP, port, stamp)) { - _metrics->process_dns_layer(dnsLayer, dir, l3Type, pcpp::TCP, flowKey, port, stamp); + _metrics->process_dns_layer(dnsLayer, dir, l3Type, pcpp::TCP, flowKey, port, _static_suffix_size, stamp); + _static_suffix_size = 0; } // data is freed upon return }; @@ -319,9 +321,10 @@ bool DnsStreamHandler::_filtering(DnsLayer &payload, [[maybe_unused]] PacketDire std::string qname_ci{payload.getFirstQuery()->getName()}; std::transform(qname_ci.begin(), qname_ci.end(), qname_ci.begin(), [](unsigned char c) { return std::tolower(c); }); - for (auto fqn : _f_qnames) { + for (const auto &fqn : _f_qnames) { // if it matched, we know we are not filtering if (endsWith(qname_ci, fqn)) { + _static_suffix_size = fqn.size(); goto will_not_filter; } } @@ -532,7 +535,7 @@ void DnsMetricsBucket::process_dnstap(bool deep, const dnstap::Dnstap &payload) process_dns_layer(deep, dpayload, l3, l4, port); } } -void DnsMetricsBucket::process_dns_layer(bool deep, DnsLayer &payload, pcpp::ProtocolType l3, Protocol l4, uint16_t port) +void DnsMetricsBucket::process_dns_layer(bool deep, DnsLayer &payload, pcpp::ProtocolType l3, Protocol l4, uint16_t port, size_t suffix_size) { std::unique_lock lock(_mutex); @@ -626,7 +629,7 @@ void DnsMetricsBucket::process_dns_layer(bool deep, DnsLayer &payload, pcpp::Pro } } - auto aggDomain = aggregateDomain(name); + auto aggDomain = aggregateDomain(name, suffix_size); _dns_topQname2.update(std::string(aggDomain.first)); if (aggDomain.second.size()) { _dns_topQname3.update(std::string(aggDomain.second)); @@ -788,12 +791,12 @@ void DnsMetricsBucket::process_filtered() } // the general metrics manager entry point (both UDP and TCP) -void DnsMetricsManager::process_dns_layer(DnsLayer &payload, PacketDirection dir, pcpp::ProtocolType l3, pcpp::ProtocolType l4, uint32_t flowkey, uint16_t port, timespec stamp) +void DnsMetricsManager::process_dns_layer(DnsLayer &payload, PacketDirection dir, pcpp::ProtocolType l3, pcpp::ProtocolType l4, uint32_t flowkey, uint16_t port, size_t suffix_size, timespec stamp) { // base event new_event(stamp); // process in the "live" bucket. this will parse the resources if we are deep sampling - live_bucket()->process_dns_layer(_deep_sampling_now, payload, l3, static_cast(l4), port); + live_bucket()->process_dns_layer(_deep_sampling_now, payload, l3, static_cast(l4), port, suffix_size); if (group_enabled(group::DnsMetrics::DnsTransactions)) { // handle dns transactions (query/response pairs) diff --git a/src/handlers/dns/DnsStreamHandler.h b/src/handlers/dns/DnsStreamHandler.h index 9b9013d58..6ab917896 100644 --- a/src/handlers/dns/DnsStreamHandler.h +++ b/src/handlers/dns/DnsStreamHandler.h @@ -158,7 +158,7 @@ class DnsMetricsBucket final : public visor::AbstractMetricsBucket void to_prometheus(std::stringstream &out, Metric::LabelMap add_labels = {}) const override; void process_filtered(); - void process_dns_layer(bool deep, DnsLayer &payload, pcpp::ProtocolType l3, Protocol l4, uint16_t port); + void process_dns_layer(bool deep, DnsLayer &payload, pcpp::ProtocolType l3, Protocol l4, uint16_t port, size_t suffix_size = 0); void process_dns_layer(pcpp::ProtocolType l3, Protocol l4, QR side, uint16_t port); void process_dnstap(bool deep, const dnstap::Dnstap &payload); @@ -200,7 +200,7 @@ class DnsMetricsManager final : public visor::AbstractMetricsManager _f_enabled; uint16_t _f_rcode{0}; std::vector _f_qnames; + size_t _static_suffix_size{0}; std::bitset _f_dnstap_types; static const inline StreamMetricsHandler::GroupDefType _group_defs = { diff --git a/src/handlers/dns/dns.cpp b/src/handlers/dns/dns.cpp index c145b8df7..d4dee9c4d 100644 --- a/src/handlers/dns/dns.cpp +++ b/src/handlers/dns/dns.cpp @@ -6,7 +6,7 @@ namespace visor::handler::dns { -AggDomainResult aggregateDomain(const std::string &domain) +AggDomainResult aggregateDomain(const std::string &domain, size_t suffix_size) { std::string_view qname2(domain); @@ -18,7 +18,9 @@ AggDomainResult aggregateDomain(const std::string &domain) return AggDomainResult(qname2, qname3); } std::size_t endDot = std::string::npos; - if (domain.back() == '.') { + if (suffix_size > 0 && domain.size() > suffix_size) { + endDot = domain.size() - suffix_size; + } else if (domain.back() == '.') { endDot = domain.size() - 2; } auto first_dot = domain.rfind('.', endDot); diff --git a/src/handlers/dns/dns.h b/src/handlers/dns/dns.h index 2a3428d17..af0a727e6 100644 --- a/src/handlers/dns/dns.h +++ b/src/handlers/dns/dns.h @@ -13,7 +13,7 @@ namespace visor::handler::dns { typedef std::pair AggDomainResult; -AggDomainResult aggregateDomain(const std::string &domain); +AggDomainResult aggregateDomain(const std::string &domain, size_t suffix_size = 0); enum QR { query = 0, diff --git a/src/handlers/dns/tests/test_dns.cpp b/src/handlers/dns/tests/test_dns.cpp index 9ce59cb88..d87d93c54 100644 --- a/src/handlers/dns/tests/test_dns.cpp +++ b/src/handlers/dns/tests/test_dns.cpp @@ -72,4 +72,47 @@ TEST_CASE("DNS Utilities", "[dns]") CHECK(result.first == ".b.c"); CHECK(result.second == ""); } + + SECTION("aggregateDomain with static suffix") + { + AggDomainResult result; + std::string domain; + std::string static_suffix; + + domain = "biz.foo.bar.com"; + static_suffix = ".bar.com"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == ".foo.bar.com"); + CHECK(result.second == "biz.foo.bar.com"); + + domain = "biz.foo.bar.com"; + static_suffix = "bar.com"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == ".foo.bar.com"); + CHECK(result.second == "biz.foo.bar.com"); + + domain = "biz.foo.bar.com"; + static_suffix = "foo.bar.com"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == "biz.foo.bar.com"); + CHECK(result.second == ""); + + domain = "foo.bar.com."; + static_suffix = "biz.foo.bar.com"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == ".bar.com."); + CHECK(result.second == "foo.bar.com."); + + domain = "www.google.co.uk"; + static_suffix = ".co.uk"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == ".google.co.uk"); + CHECK(result.second == "www.google.co.uk"); + + domain = "www.google.co.uk"; + static_suffix = "google.co.uk"; + result = aggregateDomain(domain, static_suffix.size()); + CHECK(result.first == "www.google.co.uk"); + CHECK(result.second == ""); + } } diff --git a/src/handlers/dns/tests/test_dns_layer.cpp b/src/handlers/dns/tests/test_dns_layer.cpp index 8552d6dcc..314962433 100644 --- a/src/handlers/dns/tests/test_dns_layer.cpp +++ b/src/handlers/dns/tests/test_dns_layer.cpp @@ -374,6 +374,12 @@ TEST_CASE("DNS Filters: only_qname_suffix", "[pcap][dns]") CHECK(counters.REFUSED.value() == 0); CHECK(counters.NX.value() == 1); CHECK(counters.filtered.value() == 14); + + nlohmann::json j; + dns_handler.metrics()->bucket(0)->to_json(j); + + CHECK(j["top_qname2"][0]["name"].get().find("google.com") != std::string::npos); + CHECK(j["top_qname3"][0]["name"] == nullptr); } TEST_CASE("DNS groups", "[pcap][dns]")