From ceedc2d1ff8176bbfda01c0962e66d3fd78dbdea Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 26 May 2020 15:37:37 +0300 Subject: [PATCH] *(querylog): added offset/limit parameters Actually, this is a serious refactoring of the query log module. The rest API part is refactored, it's now more clear how the search is conducted. Split the logic into more files and added more tests. Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1559 --- changelog.config.js | 4 +- openapi/CHANGELOG.md | 7 + openapi/openapi.yaml | 15 +- querylog/json.go | 330 ++++++++++++++++++++++++++++++++++ querylog/qlog.go | 267 ++-------------------------- querylog/qlog_http.go | 186 +++++++++++-------- querylog/qlog_test.go | 255 +++++++++++++++++++++++++++ querylog/querylog_search.go | 343 ++++++++---------------------------- querylog/querylog_test.go | 176 ------------------ querylog/search_criteria.go | 139 +++++++++++++++ querylog/search_params.go | 57 ++++++ util/helpers.go | 10 ++ 12 files changed, 1013 insertions(+), 776 deletions(-) create mode 100644 querylog/json.go create mode 100644 querylog/qlog_test.go delete mode 100644 querylog/querylog_test.go create mode 100644 querylog/search_criteria.go create mode 100644 querylog/search_params.go diff --git a/changelog.config.js b/changelog.config.js index 95079477c8c..427807cde49 100644 --- a/changelog.config.js +++ b/changelog.config.js @@ -16,12 +16,14 @@ module.exports = { ], "scopes": [ "", + "ui", "global", "dnsfilter", "home", "dnsforward", "dhcpd", - "documentation" + "querylog", + "documentation", ], "types": { "+": { diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 193b69fca3d..0638dddf7dc 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -1,5 +1,12 @@ # AdGuard Home API Change Log +## v0.103: API changes + +### API: Get querylog: GET /control/querylog + +* Added optional "offset" and "limit" parameters + +We are still using "older_than" approach in AdGuard Home UI, but we realize that it's easier to use offset/limit so here is this option now. ## v0.102: API changes diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index a72bd816e51..304c8dc59aa 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -143,13 +143,26 @@ paths: tags: - log operationId: queryLog - summary: Get DNS server query log + summary: Get DNS server query log. parameters: - name: older_than in: query description: Filter by older than schema: type: string + - name: offset + in: query + description: + Specify the ranking number of the first item on the page. + Even though it is possible to use "offset" and "older_than", + we recommend choosing one of them and sticking to it. + schema: + type: integer + - name: limit + in: query + description: Limit the number of records to be returned + schema: + type: integer - name: filter_domain in: query description: Filter by domain name diff --git a/querylog/json.go b/querylog/json.go new file mode 100644 index 00000000000..7efed29dd47 --- /dev/null +++ b/querylog/json.go @@ -0,0 +1,330 @@ +package querylog + +import ( + "encoding/base64" + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// decodeLogEntry - decodes query log entry from a line +// nolint (gocyclo) +func decodeLogEntry(ent *logEntry, str string) { + var b bool + var i int + var err error + for { + k, v, t := readJSON(&str) + if t == jsonTErr { + break + } + switch k { + case "IP": + if len(ent.IP) == 0 { + ent.IP = v + } + case "T": + ent.Time, err = time.Parse(time.RFC3339, v) + + case "QH": + ent.QHost = v + case "QT": + ent.QType = v + case "QC": + ent.QClass = v + + case "Answer": + ent.Answer, err = base64.StdEncoding.DecodeString(v) + case "OrigAnswer": + ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) + + case "IsFiltered": + b, err = strconv.ParseBool(v) + ent.Result.IsFiltered = b + case "Rule": + ent.Result.Rule = v + case "FilterID": + i, err = strconv.Atoi(v) + ent.Result.FilterID = int64(i) + case "Reason": + i, err = strconv.Atoi(v) + ent.Result.Reason = dnsfilter.Reason(i) + + case "Upstream": + ent.Upstream = v + case "Elapsed": + i, err = strconv.Atoi(v) + ent.Elapsed = time.Duration(i) + + // pre-v0.99.3 compatibility: + case "Question": + var qstr []byte + qstr, err = base64.StdEncoding.DecodeString(v) + if err != nil { + break + } + q := new(dns.Msg) + err = q.Unpack(qstr) + if err != nil { + break + } + ent.QHost = q.Question[0].Name + if len(ent.QHost) == 0 { + break + } + ent.QHost = ent.QHost[:len(ent.QHost)-1] + ent.QType = dns.TypeToString[q.Question[0].Qtype] + ent.QClass = dns.ClassToString[q.Question[0].Qclass] + case "Time": + ent.Time, err = time.Parse(time.RFC3339, v) + } + + if err != nil { + log.Debug("decodeLogEntry err: %s", err) + break + } + } +} + +// Get value from "key":"value" +func readJSONValue(s, name string) string { + i := strings.Index(s, "\""+name+"\":\"") + if i == -1 { + return "" + } + start := i + 1 + len(name) + 3 + i = strings.IndexByte(s[start:], '"') + if i == -1 { + return "" + } + end := start + i + return s[start:end] +} + +const ( + jsonTErr = iota + jsonTObj + jsonTStr + jsonTNum + jsonTBool +) + +// Parse JSON key-value pair +// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) +// Note the limitations: +// . doesn't support whitespace +// . doesn't support "null" +// . doesn't validate boolean or number +// . no proper handling of {} braces +// . no handling of [] brackets +// Return (key, value, type) +func readJSON(ps *string) (string, string, int32) { + s := *ps + k := "" + v := "" + t := int32(jsonTErr) + + q1 := strings.IndexByte(s, '"') + if q1 == -1 { + return k, v, t + } + q2 := strings.IndexByte(s[q1+1:], '"') + if q2 == -1 { + return k, v, t + } + k = s[q1+1 : q1+1+q2] + s = s[q1+1+q2+1:] + + if len(s) < 2 || s[0] != ':' { + return k, v, t + } + + if s[1] == '"' { + q2 = strings.IndexByte(s[2:], '"') + if q2 == -1 { + return k, v, t + } + v = s[2 : 2+q2] + t = jsonTStr + s = s[2+q2+1:] + + } else if s[1] == '{' { + t = jsonTObj + s = s[1+1:] + + } else { + sep := strings.IndexAny(s[1:], ",}") + if sep == -1 { + return k, v, t + } + v = s[1 : 1+sep] + if s[1] == 't' || s[1] == 'f' { + t = jsonTBool + } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { + t = jsonTNum + } + s = s[1+sep+1:] + } + + *ps = s + return k, v, t +} + +// Get Client IP address +func (l *queryLog) getClientIP(clientIP string) string { + if l.conf.AnonymizeClientIP { + ip := net.ParseIP(clientIP) + if ip != nil { + ip4 := ip.To4() + const AnonymizeClientIP4Mask = 24 + const AnonymizeClientIP6Mask = 112 + if ip4 != nil { + clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() + } else { + clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() + } + } + } + + return clientIP +} + +// entriesToJSON - converts log entries to JSON +func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) map[string]interface{} { + // init the response object + var data = []map[string]interface{}{} + + // the elements order is already reversed (from newer to older) + for i := 0; i < len(entries); i++ { + entry := entries[i] + jsonEntry := l.logEntryToJSONEntry(entry) + data = append(data, jsonEntry) + } + + var result = map[string]interface{}{} + result["oldest"] = "" + if !oldest.IsZero() { + result["oldest"] = oldest.Format(time.RFC3339Nano) + } + result["data"] = data + + return result +} + +func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} { + var msg *dns.Msg + + if len(entry.Answer) > 0 { + msg = new(dns.Msg) + if err := msg.Unpack(entry.Answer); err != nil { + log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) + msg = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339Nano), + "client": l.getClientIP(entry.IP), + } + jsonEntry["question"] = map[string]interface{}{ + "host": entry.QHost, + "type": entry.QType, + "class": entry.QClass, + } + + if msg != nil { + jsonEntry["status"] = dns.RcodeToString[msg.Rcode] + + opt := msg.IsEdns0() + dnssecOk := false + if opt != nil { + dnssecOk = opt.Do() + } + jsonEntry["answer_dnssec"] = dnssecOk + } + + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if len(entry.Result.ServiceName) != 0 { + jsonEntry["service_name"] = entry.Result.ServiceName + } + + answers := answerToMap(msg) + if answers != nil { + jsonEntry["answer"] = answers + } + + if len(entry.OrigAnswer) != 0 { + a := new(dns.Msg) + err := a.Unpack(entry.OrigAnswer) + if err == nil { + answers = answerToMap(a) + if answers != nil { + jsonEntry["original_answer"] = answers + } + } else { + log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) + } + } + + return jsonEntry +} + +func answerToMap(a *dns.Msg) []map[string]interface{} { + if a == nil || len(a.Answer) == 0 { + return nil + } + + var answers = []map[string]interface{}{} + for _, k := range a.Answer { + header := k.Header() + answer := map[string]interface{}{ + "type": dns.TypeToString[header.Rrtype], + "ttl": header.Ttl, + } + // try most common record types + switch v := k.(type) { + case *dns.A: + answer["value"] = v.A.String() + case *dns.AAAA: + answer["value"] = v.AAAA.String() + case *dns.MX: + answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) + case *dns.CNAME: + answer["value"] = v.Target + case *dns.NS: + answer["value"] = v.Ns + case *dns.SPF: + answer["value"] = v.Txt + case *dns.TXT: + answer["value"] = v.Txt + case *dns.PTR: + answer["value"] = v.Ptr + case *dns.SOA: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) + case *dns.CAA: + answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) + case *dns.HINFO: + answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) + case *dns.RRSIG: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) + default: + // type unknown, marshall it as-is + answer["value"] = v + } + answers = append(answers, answer) + } + + return answers +} diff --git a/querylog/qlog.go b/querylog/qlog.go index 247bb5194c3..e85da3b9912 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -1,11 +1,8 @@ package querylog import ( - "fmt" - "net" "os" "path/filepath" - "strconv" "strings" "sync" "time" @@ -17,10 +14,6 @@ import ( const ( queryLogFileName = "querylog.json" // .gz added during compression - getDataLimit = 500 // GetData(): maximum log entries to return - - // maximum entries to parse when searching - maxSearchEntries = 50000 ) // queryLog is a structure that writes and reads the DNS query log @@ -36,6 +29,23 @@ type queryLog struct { fileWriteLock sync.Mutex } +// logEntry - represents a single log entry +type logEntry struct { + IP string `json:"IP"` // Client IP + Time time.Time `json:"T"` + + QHost string `json:"QH"` + QType string `json:"QT"` + QClass string `json:"QC"` + + Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + OrigAnswer []byte `json:",omitempty"` + + Result dnsfilter.Result + Elapsed time.Duration + Upstream string `json:",omitempty"` // if empty, means it was cached +} + // create a new instance of the query log func newQueryLog(conf Config) *queryLog { l := queryLog{} @@ -93,22 +103,6 @@ func (l *queryLog) clear() { log.Debug("Query log: cleared") } -type logEntry struct { - IP string `json:"IP"` - Time time.Time `json:"T"` - - QHost string `json:"QH"` - QType string `json:"QT"` - QClass string `json:"QC"` - - Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net - OrigAnswer []byte `json:",omitempty"` - - Result dnsfilter.Result - Elapsed time.Duration - Upstream string `json:",omitempty"` // if empty, means it was cached -} - func (l *queryLog) Add(params AddParams) { if !l.conf.Enabled { return @@ -173,230 +167,3 @@ func (l *queryLog) Add(params AddParams) { go l.flushLogBuffer(false) // nolint } } - -// Parameters for getData() -type getDataParams struct { - OlderThan time.Time // return entries that are older than this value - Domain string // filter by domain name in question - Client string // filter by client IP - QuestionType string // filter by question type - ResponseStatus responseStatusType // filter by response status - StrictMatchDomain bool // if Domain value must be matched strictly - StrictMatchClient bool // if Client value must be matched strictly -} - -// Response status -type responseStatusType int32 - -// Response status constants -const ( - responseStatusAll responseStatusType = iota + 1 - responseStatusFiltered -) - -// Gets log entries -func (l *queryLog) getData(params getDataParams) map[string]interface{} { - now := time.Now() - - if len(params.Client) != 0 && l.conf.AnonymizeClientIP { - params.Client = l.getClientIP(params.Client) - } - - // add from file - fileEntries, oldest, total := l.searchFiles(params) - - if params.OlderThan.IsZero() { - // In case if the timer is not precise (for instance, on Windows) - // We really want to get all records including those added just before the call - params.OlderThan = now.Add(time.Millisecond) - } - - // add from memory buffer - l.bufferLock.Lock() - total += len(l.buffer) - memoryEntries := make([]*logEntry, 0) - - // go through the buffer in the reverse order - // from NEWER to OLDER - for i := len(l.buffer) - 1; i >= 0; i-- { - entry := l.buffer[i] - - if entry.Time.UnixNano() >= params.OlderThan.UnixNano() { - // Ignore entries newer than what was requested - continue - } - - if !matchesGetDataParams(entry, params) { - continue - } - - memoryEntries = append(memoryEntries, entry) - } - l.bufferLock.Unlock() - - // now let's get a unified collection - entries := append(memoryEntries, fileEntries...) - if len(entries) > getDataLimit { - // remove extra records - entries = entries[:getDataLimit] - } - if len(entries) == getDataLimit { - // change the "oldest" value here. - // we cannot use the "oldest" we got from "searchFiles" anymore - // because after adding in-memory records and removing extra records - // the situation has changed - oldest = entries[len(entries)-1].Time - } - - // init the response object - var data = []map[string]interface{}{} - - // the elements order is already reversed (from newer to older) - for i := 0; i < len(entries); i++ { - entry := entries[i] - jsonEntry := l.logEntryToJSONEntry(entry) - data = append(data, jsonEntry) - } - - log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s", - len(entries), total, params.OlderThan, time.Since(now)) - - var result = map[string]interface{}{} - result["oldest"] = "" - if !oldest.IsZero() { - result["oldest"] = oldest.Format(time.RFC3339Nano) - } - result["data"] = data - return result -} - -// Get Client IP address -func (l *queryLog) getClientIP(clientIP string) string { - if l.conf.AnonymizeClientIP { - ip := net.ParseIP(clientIP) - if ip != nil { - ip4 := ip.To4() - const AnonymizeClientIP4Mask = 24 - const AnonymizeClientIP6Mask = 112 - if ip4 != nil { - clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() - } else { - clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() - } - } - } - - return clientIP -} - -func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} { - var msg *dns.Msg - - if len(entry.Answer) > 0 { - msg = new(dns.Msg) - if err := msg.Unpack(entry.Answer); err != nil { - log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) - msg = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339Nano), - "client": l.getClientIP(entry.IP), - } - jsonEntry["question"] = map[string]interface{}{ - "host": entry.QHost, - "type": entry.QType, - "class": entry.QClass, - } - - if msg != nil { - jsonEntry["status"] = dns.RcodeToString[msg.Rcode] - - opt := msg.IsEdns0() - dnssecOk := false - if opt != nil { - dnssecOk = opt.Do() - } - jsonEntry["answer_dnssec"] = dnssecOk - } - - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if len(entry.Result.ServiceName) != 0 { - jsonEntry["service_name"] = entry.Result.ServiceName - } - - answers := answerToMap(msg) - if answers != nil { - jsonEntry["answer"] = answers - } - - if len(entry.OrigAnswer) != 0 { - a := new(dns.Msg) - err := a.Unpack(entry.OrigAnswer) - if err == nil { - answers = answerToMap(a) - if answers != nil { - jsonEntry["original_answer"] = answers - } - } else { - log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) - } - } - - return jsonEntry -} - -func answerToMap(a *dns.Msg) []map[string]interface{} { - if a == nil || len(a.Answer) == 0 { - return nil - } - - var answers = []map[string]interface{}{} - for _, k := range a.Answer { - header := k.Header() - answer := map[string]interface{}{ - "type": dns.TypeToString[header.Rrtype], - "ttl": header.Ttl, - } - // try most common record types - switch v := k.(type) { - case *dns.A: - answer["value"] = v.A.String() - case *dns.AAAA: - answer["value"] = v.AAAA.String() - case *dns.MX: - answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) - case *dns.CNAME: - answer["value"] = v.Target - case *dns.NS: - answer["value"] = v.Ns - case *dns.SPF: - answer["value"] = v.Txt - case *dns.TXT: - answer["value"] = v.Txt - case *dns.PTR: - answer["value"] = v.Ptr - case *dns.SOA: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) - case *dns.CAA: - answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) - case *dns.HINFO: - answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) - case *dns.RRSIG: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) - default: - // type unknown, marshall it as-is - answer["value"] = v - } - answers = append(answers, answer) - } - - return answers -} diff --git a/querylog/qlog_http.go b/querylog/qlog_http.go index fae8dba6313..19caa35cee5 100644 --- a/querylog/qlog_http.go +++ b/querylog/qlog_http.go @@ -4,13 +4,30 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "strconv" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) +type qlogConfig struct { + Enabled bool `json:"enabled"` + Interval uint32 `json:"interval"` + AnonymizeClientIP bool `json:"anonymize_client_ip"` +} + +// Register web handlers +func (l *queryLog) initWeb() { + l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog) + l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo) + l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear) + l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig) +} + func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) @@ -19,74 +36,18 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string, http.Error(w, text, code) } -type request struct { - olderThan string - filterDomain string - filterClient string - filterQuestionType string - filterResponseStatus string -} - -// "value" -> value, return TRUE -func getDoubleQuotesEnclosedValue(s *string) bool { - t := *s - if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' { - *s = t[1 : len(t)-1] - return true - } - return false -} - func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { - var err error - req := request{} - q := r.URL.Query() - req.olderThan = q.Get("older_than") - req.filterDomain = q.Get("filter_domain") - req.filterClient = q.Get("filter_client") - req.filterQuestionType = q.Get("filter_question_type") - req.filterResponseStatus = q.Get("filter_response_status") - - params := getDataParams{ - Domain: req.filterDomain, - Client: req.filterClient, - ResponseStatus: responseStatusAll, - } - if len(req.olderThan) != 0 { - params.OlderThan, err = time.Parse(time.RFC3339Nano, req.olderThan) - if err != nil { - httpError(r, w, http.StatusBadRequest, "invalid time stamp: %s", err) - return - } - } - - if getDoubleQuotesEnclosedValue(¶ms.Domain) { - params.StrictMatchDomain = true - } - if getDoubleQuotesEnclosedValue(¶ms.Client) { - params.StrictMatchClient = true - } - - if len(req.filterQuestionType) != 0 { - _, ok := dns.StringToType[req.filterQuestionType] - if !ok { - httpError(r, w, http.StatusBadRequest, "invalid question_type") - return - } - params.QuestionType = req.filterQuestionType + params, err := l.parseSearchParams(r) + if err != nil { + httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err) + return } - if len(req.filterResponseStatus) != 0 { - switch req.filterResponseStatus { - case "filtered": - params.ResponseStatus = responseStatusFiltered - default: - httpError(r, w, http.StatusBadRequest, "invalid response_status") - return - } - } + // search for the log entries + entries, oldest := l.search(params) - data := l.getData(params) + // convert log entries to JSON + var data = l.entriesToJSON(entries, oldest) jsonVal, err := json.Marshal(data) if err != nil { @@ -101,16 +62,10 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { } } -func (l *queryLog) handleQueryLogClear(w http.ResponseWriter, r *http.Request) { +func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { l.clear() } -type qlogConfig struct { - Enabled bool `json:"enabled"` - Interval uint32 `json:"interval"` - AnonymizeClientIP bool `json:"anonymize_client_ip"` -} - // Get configuration func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { resp := qlogConfig{} @@ -162,10 +117,85 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) l.conf.ConfigModified() } -// Register web handlers -func (l *queryLog) initWeb() { - l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog) - l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo) - l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear) - l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig) +// "value" -> value, return TRUE +func getDoubleQuotesEnclosedValue(s *string) bool { + t := *s + if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' { + *s = t[1 : len(t)-1] + return true + } + return false +} + +// parseSearchCriteria - parses "searchCriteria" from the specified query parameter +func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) { + val := q.Get(name) + if len(val) == 0 { + return false, searchCriteria{}, nil + } + + c := searchCriteria{ + criteriaType: ct, + value: val, + } + if getDoubleQuotesEnclosedValue(&c.value) { + c.strict = true + } + + if ct == ctClient && l.conf.AnonymizeClientIP { + c.value = l.getClientIP(c.value) + } + + if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) { + return false, c, fmt.Errorf("invalid value %s", c.value) + } + + return true, c, nil +} + +// parseSearchParams - parses "searchParams" from the HTTP request's query string +func (l *queryLog) parseSearchParams(r *http.Request) (*searchParams, error) { + p := newSearchParams() + + var err error + q := r.URL.Query() + olderThan := q.Get("older_than") + if len(olderThan) != 0 { + p.olderThan, err = time.Parse(time.RFC3339Nano, olderThan) + if err != nil { + return nil, err + } + } + + if limit, err := strconv.ParseInt(q.Get("limit"), 10, 64); err == nil { + p.limit = int(limit) + + // If limit or offset are specified explicitly, we should change the default behavior + // and scan all log records until we found enough log entries + p.maxFileScanEntries = 0 + } + if offset, err := strconv.ParseInt(q.Get("offset"), 10, 64); err == nil { + p.offset = int(offset) + p.maxFileScanEntries = 0 + } + + paramNames := map[string]criteriaType{ + "filter_domain": ctDomain, + "filter_client": ctClient, + "filter_question_type": ctQuestionType, + "filter_response_status": ctFilteringStatus, + } + + for k, v := range paramNames { + ok, c, err := l.parseSearchCriteria(q, k, v) + if err != nil { + return nil, err + } + + if ok { + p.searchCriteria = append(p.searchCriteria, c) + } + } + + return p, nil } diff --git a/querylog/qlog_test.go b/querylog/qlog_test.go new file mode 100644 index 00000000000..68c346f581d --- /dev/null +++ b/querylog/qlog_test.go @@ -0,0 +1,255 @@ +package querylog + +import ( + "net" + "os" + "testing" + + "github.com/AdguardTeam/dnsproxy/proxyutil" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func prepareTestDir() string { + const dir = "./agh-test" + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, 0755) + return dir +} + +// Check adding and loading (with filtering) entries from disk and memory +func TestQueryLog(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add disk entries + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + // write to disk (first file) + _ = l.flushLogBuffer(true) + // start writing to the second file + _ = l.rotate() + // add disk entries + addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") + // write to disk + _ = l.flushLogBuffer(true) + // add memory entries + addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") + addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") + + // get all entries + params := newSearchParams() + entries, _ := l.search(params) + assert.Equal(t, 4, len(entries)) + assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") + assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") + + // search by domain (strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctDomain, + strict: true, + value: "test.example.org", + }) + entries, _ = l.search(params) + assert.Equal(t, 1, len(entries)) + assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") + + // search by domain (not strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctDomain, + strict: false, + value: "example.org", + }) + entries, _ = l.search(params) + assert.Equal(t, 3, len(entries)) + assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") + + // search by client IP (strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctClient, + strict: true, + value: "2.2.2.2", + }) + entries, _ = l.search(params) + assert.Equal(t, 1, len(entries)) + assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") + + // search by client IP (part of) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctClient, + strict: false, + value: "2.2.2", + }) + entries, _ = l.search(params) + assert.Equal(t, 4, len(entries)) + assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") + assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") +} + +func TestQueryLogOffsetLimit(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add 10 entries to the log + for i := 0; i < 10; i++ { + addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1") + } + // write them to disk (first file) + _ = l.flushLogBuffer(true) + // add 10 more entries to the log (memory) + for i := 0; i < 10; i++ { + addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1") + } + + // First page + params := newSearchParams() + params.offset = 0 + params.limit = 10 + entries, _ := l.search(params) + assert.Equal(t, 10, len(entries)) + assert.Equal(t, entries[0].QHost, "first.example.org") + assert.Equal(t, entries[9].QHost, "first.example.org") + + // Second page + params.offset = 10 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 10, len(entries)) + assert.Equal(t, entries[0].QHost, "second.example.org") + assert.Equal(t, entries[9].QHost, "second.example.org") + + // Second and a half page + params.offset = 15 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 5, len(entries)) + assert.Equal(t, entries[0].QHost, "second.example.org") + assert.Equal(t, entries[4].QHost, "second.example.org") + + // Third page + params.offset = 20 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 0, len(entries)) +} + +func TestQueryLogMaxFileScanEntries(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add 10 entries to the log + for i := 0; i < 10; i++ { + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + } + // write them to disk (first file) + _ = l.flushLogBuffer(true) + + params := newSearchParams() + params.maxFileScanEntries = 5 // do not scan more than 5 records + entries, _ := l.search(params) + assert.Equal(t, 5, len(entries)) + + params.maxFileScanEntries = 0 // disable the limit + entries, _ = l.search(params) + assert.Equal(t, 10, len(entries)) +} + +func TestJSON(t *testing.T) { + s := ` + {"keystr":"val","obj":{"keybool":true,"keyint":123456}} + ` + k, v, jtype := readJSON(&s) + assert.Equal(t, jtype, int32(jsonTStr)) + assert.Equal(t, "keystr", k) + assert.Equal(t, "val", v) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTObj)) + assert.Equal(t, "obj", k) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTBool)) + assert.Equal(t, "keybool", k) + assert.Equal(t, "true", v) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTNum)) + assert.Equal(t, "keyint", k) + assert.Equal(t, "123456", v) + + k, v, jtype = readJSON(&s) + assert.True(t, jtype == jsonTErr) +} + +func addEntry(l *queryLog, host, answerStr, client string) { + q := dns.Msg{} + q.Question = append(q.Question, dns.Question{ + Name: host + ".", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }) + + a := dns.Msg{} + a.Question = append(a.Question, q.Question[0]) + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: q.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + } + answer.A = net.ParseIP(answerStr) + a.Answer = append(a.Answer, answer) + res := dnsfilter.Result{} + params := AddParams{ + Question: &q, + Answer: &a, + Result: &res, + ClientIP: net.ParseIP(client), + Upstream: "upstream", + } + l.Add(params) +} + +func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool { + assert.Equal(t, host, entry.QHost) + assert.Equal(t, client, entry.IP) + assert.Equal(t, "A", entry.QType) + assert.Equal(t, "IN", entry.QClass) + + msg := new(dns.Msg) + assert.Nil(t, msg.Unpack(entry.Answer)) + assert.Equal(t, 1, len(msg.Answer)) + ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) + assert.NotNil(t, ip) + assert.Equal(t, answer, ip.String()) + return true +} diff --git a/querylog/querylog_search.go b/querylog/querylog_search.go index f9493af9531..eda1f92d869 100644 --- a/querylog/querylog_search.go +++ b/querylog/querylog_search.go @@ -1,18 +1,72 @@ package querylog import ( - "encoding/base64" "io" - "strconv" - "strings" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) +// search - searches log entries in the query log using specified parameters +// returns the list of entries found + time of the oldest entry +func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) { + now := time.Now() + + if params.limit == 0 { + return []*logEntry{}, time.Time{} + } + + // add from file + fileEntries, oldest, total := l.searchFiles(params) + + // add from memory buffer + l.bufferLock.Lock() + total += len(l.buffer) + memoryEntries := make([]*logEntry, 0) + + // go through the buffer in the reverse order + // from NEWER to OLDER + for i := len(l.buffer) - 1; i >= 0; i-- { + entry := l.buffer[i] + if !params.match(entry) { + continue + } + memoryEntries = append(memoryEntries, entry) + } + l.bufferLock.Unlock() + + // limits + totalLimit := params.offset + params.limit + + // now let's get a unified collection + entries := append(memoryEntries, fileEntries...) + if len(entries) > totalLimit { + // remove extra records + entries = entries[:totalLimit] + } + if params.offset > 0 { + if len(entries) > params.offset { + entries = entries[params.offset:] + } else { + entries = make([]*logEntry, 0) + oldest = time.Time{} + } + } + if len(entries) == totalLimit { + // change the "oldest" value here. + // we cannot use the "oldest" we got from "searchFiles" anymore + // because after adding in-memory records and removing extra records + // the situation has changed + oldest = entries[len(entries)-1].Time + } + + log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s", + len(entries), total, params.olderThan, time.Since(now)) + + return entries, oldest +} + // searchFiles reads log entries from all log files and applies the specified search criteria. // IMPORTANT: this method does not scan more than "maxSearchEntries" so you // may need to call it many times. @@ -21,7 +75,7 @@ import ( // * an array of log entries that we have read // * time of the oldest processed entry (even if it was discarded) // * total number of processed entries (including discarded). -func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) { +func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, int) { entries := make([]*logEntry, 0) oldest := time.Time{} @@ -32,10 +86,10 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in } defer r.Close() - if params.OlderThan.IsZero() { + if params.olderThan.IsZero() { err = r.SeekStart() } else { - err = r.Seek(params.OlderThan.UnixNano()) + err = r.Seek(params.olderThan.UnixNano()) if err == nil { // Read to the next record right away // The one that was specified in the "oldest" param is not needed, @@ -45,14 +99,17 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in } if err != nil { - log.Debug("Cannot Seek() to %v: %v", params.OlderThan, err) + log.Debug("Cannot Seek() to %v: %v", params.olderThan, err) return entries, oldest, 0 } + totalLimit := params.offset + params.limit total := 0 oldestNano := int64(0) - // Do not scan more than 50k at once - for total <= maxSearchEntries { + // By default, we do not scan more than "maxFileScanEntries" at once + // The idea is to make search calls faster so that the UI could handle it and show something + // This behavior can be overridden if "maxFileScanEntries" is set to 0 + for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { entry, ts, err := l.readNextEntry(r, params) if err == io.EOF { @@ -65,8 +122,8 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in if entry != nil { entries = append(entries, entry) - if len(entries) == getDataLimit { - // Do not read more than "getDataLimit" records at once + if len(entries) == totalLimit { + // Do not read more than "totalLimit" records at once break } } @@ -82,7 +139,7 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in // * log entry that matches search criteria or null if it was discarded (or if there's nothing to read) // * timestamp of the processed log entry // * error if we can't read anymore -func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) { +func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry, int64, error) { line, err := r.ReadNext() if err != nil { return nil, 0, err @@ -92,7 +149,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry timestamp := readQLogTimestamp(line) // Quick check without deserializing log entry - if !quickMatchesGetDataParams(line, params) { + if !params.quickMatch(line) { return nil, timestamp, nil } @@ -100,7 +157,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry decodeLogEntry(&entry, line) // Full check of the deserialized log entry - if !matchesGetDataParams(&entry, params) { + if !params.match(&entry) { return nil, timestamp, nil } @@ -120,257 +177,3 @@ func (l *queryLog) openReader() (*QLogReader, error) { return NewQLogReader(files) } - -// quickMatchesGetDataParams - quickly checks if the line matches getDataParams -// this method does not guarantee anything and the reason is to do a quick check -// without deserializing anything -func quickMatchesGetDataParams(line string, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered { - boolVal, ok := readJSONBool(line, "IsFiltered") - if !ok || !boolVal { - return false - } - } - - if len(params.Domain) != 0 { - val := readJSONValue(line, "QH") - if len(val) == 0 { - return false - } - - if (params.StrictMatchDomain && val != params.Domain) || - (!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) { - return false - } - } - - if len(params.QuestionType) != 0 { - val := readJSONValue(line, "QT") - if val != params.QuestionType { - return false - } - } - - if len(params.Client) != 0 { - val := readJSONValue(line, "IP") - if len(val) == 0 { - log.Debug("QueryLog: failed to decodeLogEntry") - return false - } - - if (params.StrictMatchClient && val != params.Client) || - (!params.StrictMatchClient && strings.Index(val, params.Client) == -1) { - return false - } - } - - return true -} - -// matchesGetDataParams - returns true if the entry matches the search parameters -func matchesGetDataParams(entry *logEntry, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { - return false - } - - if len(params.QuestionType) != 0 { - if entry.QType != params.QuestionType { - return false - } - } - - if len(params.Domain) != 0 { - if (params.StrictMatchDomain && entry.QHost != params.Domain) || - (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { - return false - } - } - - if len(params.Client) != 0 { - if (params.StrictMatchClient && entry.IP != params.Client) || - (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { - return false - } - } - - return true -} - -// decodeLogEntry - decodes query log entry from a line -// nolint (gocyclo) -func decodeLogEntry(ent *logEntry, str string) { - var b bool - var i int - var err error - for { - k, v, t := readJSON(&str) - if t == jsonTErr { - break - } - switch k { - case "IP": - if len(ent.IP) == 0 { - ent.IP = v - } - case "T": - ent.Time, err = time.Parse(time.RFC3339, v) - - case "QH": - ent.QHost = v - case "QT": - ent.QType = v - case "QC": - ent.QClass = v - - case "Answer": - ent.Answer, err = base64.StdEncoding.DecodeString(v) - case "OrigAnswer": - ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) - - case "IsFiltered": - b, err = strconv.ParseBool(v) - ent.Result.IsFiltered = b - case "Rule": - ent.Result.Rule = v - case "FilterID": - i, err = strconv.Atoi(v) - ent.Result.FilterID = int64(i) - case "Reason": - i, err = strconv.Atoi(v) - ent.Result.Reason = dnsfilter.Reason(i) - - case "Upstream": - ent.Upstream = v - case "Elapsed": - i, err = strconv.Atoi(v) - ent.Elapsed = time.Duration(i) - - // pre-v0.99.3 compatibility: - case "Question": - var qstr []byte - qstr, err = base64.StdEncoding.DecodeString(v) - if err != nil { - break - } - q := new(dns.Msg) - err = q.Unpack(qstr) - if err != nil { - break - } - ent.QHost = q.Question[0].Name - if len(ent.QHost) == 0 { - break - } - ent.QHost = ent.QHost[:len(ent.QHost)-1] - ent.QType = dns.TypeToString[q.Question[0].Qtype] - ent.QClass = dns.ClassToString[q.Question[0].Qclass] - case "Time": - ent.Time, err = time.Parse(time.RFC3339, v) - } - - if err != nil { - log.Debug("decodeLogEntry err: %s", err) - break - } - } -} - -// Get bool value from "key":bool -func readJSONBool(s, name string) (bool, bool) { - i := strings.Index(s, "\""+name+"\":") - if i == -1 { - return false, false - } - start := i + 1 + len(name) + 2 - b := false - if strings.HasPrefix(s[start:], "true") { - b = true - } else if !strings.HasPrefix(s[start:], "false") { - return false, false - } - return b, true -} - -// Get value from "key":"value" -func readJSONValue(s, name string) string { - i := strings.Index(s, "\""+name+"\":\"") - if i == -1 { - return "" - } - start := i + 1 + len(name) + 3 - i = strings.IndexByte(s[start:], '"') - if i == -1 { - return "" - } - end := start + i - return s[start:end] -} - -const ( - jsonTErr = iota - jsonTObj - jsonTStr - jsonTNum - jsonTBool -) - -// Parse JSON key-value pair -// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) -// Note the limitations: -// . doesn't support whitespace -// . doesn't support "null" -// . doesn't validate boolean or number -// . no proper handling of {} braces -// . no handling of [] brackets -// Return (key, value, type) -func readJSON(ps *string) (string, string, int32) { - s := *ps - k := "" - v := "" - t := int32(jsonTErr) - - q1 := strings.IndexByte(s, '"') - if q1 == -1 { - return k, v, t - } - q2 := strings.IndexByte(s[q1+1:], '"') - if q2 == -1 { - return k, v, t - } - k = s[q1+1 : q1+1+q2] - s = s[q1+1+q2+1:] - - if len(s) < 2 || s[0] != ':' { - return k, v, t - } - - if s[1] == '"' { - q2 = strings.IndexByte(s[2:], '"') - if q2 == -1 { - return k, v, t - } - v = s[2 : 2+q2] - t = jsonTStr - s = s[2+q2+1:] - - } else if s[1] == '{' { - t = jsonTObj - s = s[1+1:] - - } else { - sep := strings.IndexAny(s[1:], ",}") - if sep == -1 { - return k, v, t - } - v = s[1 : 1+sep] - if s[1] == 't' || s[1] == 'f' { - t = jsonTBool - } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { - t = jsonTNum - } - s = s[1+sep+1:] - } - - *ps = s - return k, v, t -} diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go deleted file mode 100644 index 06de4101127..00000000000 --- a/querylog/querylog_test.go +++ /dev/null @@ -1,176 +0,0 @@ -package querylog - -import ( - "net" - "os" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" -) - -func prepareTestDir() string { - const dir = "./agh-test" - _ = os.RemoveAll(dir) - _ = os.MkdirAll(dir, 0755) - return dir -} - -// Check adding and loading (with filtering) entries from disk and memory -func TestQueryLog(t *testing.T) { - conf := Config{ - Enabled: true, - Interval: 1, - MemSize: 100, - } - conf.BaseDir = prepareTestDir() - defer func() { _ = os.RemoveAll(conf.BaseDir) }() - l := newQueryLog(conf) - - // add disk entries - addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") - // write to disk (first file) - _ = l.flushLogBuffer(true) - // start writing to the second file - _ = l.rotate() - // add disk entries - addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") - // write to disk - _ = l.flushLogBuffer(true) - // add memory entries - addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") - addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") - - // get all entries - params := getDataParams{ - OlderThan: time.Time{}, - } - d := l.getData(params) - mdata := d["data"].([]map[string]interface{}) - assert.Equal(t, 4, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) - assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) - - // search by domain (strict) - params = getDataParams{ - OlderThan: time.Time{}, - Domain: "test.example.org", - StrictMatchDomain: true, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - - // search by domain (not strict) - params = getDataParams{ - OlderThan: time.Time{}, - Domain: "example.org", - StrictMatchDomain: false, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 3, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) - - // search by client IP (strict) - params = getDataParams{ - OlderThan: time.Time{}, - Client: "2.2.2.2", - StrictMatchClient: true, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 1, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2")) - - // search by client IP (part of) - params = getDataParams{ - OlderThan: time.Time{}, - Client: "2.2.2", - StrictMatchClient: false, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 4, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) - assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) -} - -func addEntry(l *queryLog, host, answerStr, client string) { - q := dns.Msg{} - q.Question = append(q.Question, dns.Question{ - Name: host + ".", - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }) - - a := dns.Msg{} - a.Question = append(a.Question, q.Question[0]) - answer := new(dns.A) - answer.Hdr = dns.RR_Header{ - Name: q.Question[0].Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - } - answer.A = net.ParseIP(answerStr) - a.Answer = append(a.Answer, answer) - res := dnsfilter.Result{} - params := AddParams{ - Question: &q, - Answer: &a, - Result: &res, - ClientIP: net.ParseIP(client), - Upstream: "upstream", - } - l.Add(params) -} - -func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool { - mq := m["question"].(map[string]interface{}) - ma := m["answer"].([]map[string]interface{}) - ma0 := ma[0] - if !assert.Equal(t, host, mq["host"].(string)) || - !assert.Equal(t, "IN", mq["class"].(string)) || - !assert.Equal(t, "A", mq["type"].(string)) || - !assert.Equal(t, answer, ma0["value"].(string)) || - !assert.Equal(t, client, m["client"].(string)) { - return false - } - return true -} - -func TestJSON(t *testing.T) { - s := ` - {"keystr":"val","obj":{"keybool":true,"keyint":123456}} - ` - k, v, jtype := readJSON(&s) - assert.Equal(t, jtype, int32(jsonTStr)) - assert.Equal(t, "keystr", k) - assert.Equal(t, "val", v) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTObj)) - assert.Equal(t, "obj", k) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTBool)) - assert.Equal(t, "keybool", k) - assert.Equal(t, "true", v) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTNum)) - assert.Equal(t, "keyint", k) - assert.Equal(t, "123456", v) - - k, v, jtype = readJSON(&s) - assert.True(t, jtype == jsonTErr) -} diff --git a/querylog/search_criteria.go b/querylog/search_criteria.go new file mode 100644 index 00000000000..b2ba63f6cce --- /dev/null +++ b/querylog/search_criteria.go @@ -0,0 +1,139 @@ +package querylog + +import ( + "strings" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" +) + +type criteriaType int + +const ( + ctDomain criteriaType = iota // domain name + ctClient // client IP address + ctQuestionType // question type + ctFilteringStatus // filtering status +) + +const ( + filteringStatusAll = "all" + filteringStatusFiltered = "filtered" // all kinds of filtering + + filteringStatusBlocked = "blocked" // blocked or blocked service + filteringStatusBlockedSafebrowsing = "blocked_safebrowsing" // blocked by safebrowsing + filteringStatusBlockedParental = "blocked_parental" // blocked by parental control + filteringStatusWhitelisted = "whitelisted" // whitelisted + filteringStatusRewritten = "rewritten" // all kinds of rewrites + filteringStatusSafeSearch = "safe_search" // enforced safe search +) + +// filteringStatusValues -- array with all possible filteringStatus values +var filteringStatusValues = []string{ + filteringStatusAll, filteringStatusFiltered, filteringStatusBlocked, + filteringStatusBlockedSafebrowsing, filteringStatusBlockedParental, + filteringStatusWhitelisted, filteringStatusRewritten, filteringStatusSafeSearch, +} + +// searchCriteria - every search request may contain a list of different search criteria +// we use each of them to match the query +type searchCriteria struct { + criteriaType criteriaType // type of the criteria + strict bool // should we strictly match (equality) or not (indexOf) + value string // search criteria value +} + +// quickMatch - quickly checks if the log entry matches this search criteria +// the reason is to do it as quickly as possible without de-serializing the entry +func (c *searchCriteria) quickMatch(line string) bool { + // note that we do this only for a limited set of criteria + + switch c.criteriaType { + case ctDomain: + return c.quickMatchJSONValue(line, "QH") + case ctClient: + return c.quickMatchJSONValue(line, "IP") + case ctQuestionType: + return c.quickMatchJSONValue(line, "QT") + default: + return true + } +} + +// quickMatchJSONValue - helper used by quickMatch +func (c *searchCriteria) quickMatchJSONValue(line string, propertyName string) bool { + val := readJSONValue(line, propertyName) + if len(val) == 0 { + return false + } + + if c.strict && c.value == val { + return true + } + if !c.strict && strings.Contains(val, c.value) { + return true + } + + return false +} + +// match - checks if the log entry matches this search criteria +// nolint (gocyclo) +func (c *searchCriteria) match(entry *logEntry) bool { + switch c.criteriaType { + case ctDomain: + if c.strict && entry.QHost == c.value { + return true + } + if !c.strict && strings.Contains(entry.QHost, c.value) { + return true + } + return false + case ctClient: + if c.strict && entry.IP == c.value { + return true + } + if !c.strict && strings.Contains(entry.IP, c.value) { + return true + } + return false + case ctQuestionType: + if c.strict && entry.QType == c.value { + return true + } + if !c.strict && strings.Contains(entry.QType, c.value) { + return true + } + case ctFilteringStatus: + res := entry.Result + + switch c.value { + case filteringStatusAll: + return true + case filteringStatusFiltered: + return res.IsFiltered + case filteringStatusBlocked: + return res.IsFiltered && + (res.Reason == dnsfilter.FilteredBlackList || + res.Reason == dnsfilter.FilteredBlockedService) + case filteringStatusBlockedParental: + return res.IsFiltered && res.Reason == dnsfilter.FilteredParental + case filteringStatusBlockedSafebrowsing: + return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeBrowsing + case filteringStatusWhitelisted: + return res.IsFiltered && res.Reason == dnsfilter.NotFilteredWhiteList + case filteringStatusRewritten: + return res.IsFiltered && + (res.Reason == dnsfilter.ReasonRewrite || + res.Reason == dnsfilter.RewriteEtcHosts) + case filteringStatusSafeSearch: + return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch + default: + return false + } + + default: + return false + } + + return false +} diff --git a/querylog/search_params.go b/querylog/search_params.go new file mode 100644 index 00000000000..da083e452dc --- /dev/null +++ b/querylog/search_params.go @@ -0,0 +1,57 @@ +package querylog + +import "time" + +// searchParams represent the search query sent by the client +type searchParams struct { + // searchCriteria - list of search criteria that we use to get filter results + searchCriteria []searchCriteria + + // olderThen - return entries that are older than this value + // if not set - disregard it and return any value + olderThan time.Time + + offset int // offset for the search + limit int // limit the number of records returned + maxFileScanEntries int // maximum log entries to scan in query log files. if 0 - no limit +} + +// newSearchParams - creates an empty instance of searchParams +func newSearchParams() *searchParams { + return &searchParams{ + // default max log entries to return + limit: 500, + + // by default, we scan up to 50k entries at once + maxFileScanEntries: 50000, + } +} + +// quickMatchesGetDataParams - quickly checks if the line matches the searchParams +// this method does not guarantee anything and the reason is to do a quick check +// without deserializing anything +func (s *searchParams) quickMatch(line string) bool { + for _, c := range s.searchCriteria { + if !c.quickMatch(line) { + return false + } + } + + return true +} + +// match - checks if the logEntry matches the searchParams +func (s *searchParams) match(entry *logEntry) bool { + if !s.olderThan.IsZero() && entry.Time.UnixNano() >= s.olderThan.UnixNano() { + // Ignore entries newer than what was requested + return false + } + + for _, c := range s.searchCriteria { + if !c.match(entry) { + return false + } + } + + return true +} diff --git a/util/helpers.go b/util/helpers.go index 27ac4d710ab..b32ac456f17 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,6 +10,16 @@ import ( "strings" ) +// ContainsString checks if "v" is in the array "arr" +func ContainsString(arr []string, v string) bool { + for _, i := range arr { + if i == v { + return true + } + } + return false +} + // fileExists returns TRUE if file exists func FileExists(fn string) bool { _, err := os.Stat(fn)