From fc31a0599e667a2aafccd8b4fba03de314136edb Mon Sep 17 00:00:00 2001 From: Stefano Gabryel Date: Sun, 28 Jul 2019 08:58:36 +0200 Subject: [PATCH] Introducing sigint handling --- functional-tests.sh | 6 +- pkg/cmd/scan.go | 26 ++++-- pkg/cmd/scan_integration_test.go | 64 ++++++++++++++- pkg/scan/producer/reproducer.go | 2 +- pkg/scan/producer/reproducer_test.go | 33 ++++---- pkg/scan/result_logger.go | 4 +- pkg/scan/result_logger_test.go | 12 ++- pkg/scan/result_summarizer.go | 18 ++--- pkg/scan/result_summarizer_test.go | 113 ++++++++++++++++----------- pkg/scan/scanner.go | 21 +++-- 10 files changed, 205 insertions(+), 94 deletions(-) diff --git a/functional-tests.sh b/functional-tests.sh index 9423d35..7902e54 100755 --- a/functional-tests.sh +++ b/functional-tests.sh @@ -65,9 +65,9 @@ SCAN_RESULT=$(./dist/dirstalk scan -d resources/tests/dictionary.txt http://loca assert_contains "$SCAN_RESULT" "/index" "result expected when performing scan" assert_contains "$SCAN_RESULT" "/index/home" "result expected when performing scan" assert_contains "$SCAN_RESULT" "8 requests made, 3 results found" "a recap was expected when performing a scan" -assert_contains "$SCAN_RESULT" "├── index" "a recap was expected when performing a scan" -assert_contains "$SCAN_RESULT" "│ └── home" "a recap was expected when performing a scan" -assert_contains "$SCAN_RESULT" "└── home" "a recap was expected when performing a scan" +assert_contains "$SCAN_RESULT" "├── home" "a recap was expected when performing a scan" +assert_contains "$SCAN_RESULT" "└── index" "a recap was expected when performing a scan" +assert_contains "$SCAN_RESULT" " └── home" "a recap was expected when performing a scan" assert_not_contains "$SCAN_RESULT" "error" "no error is expected for a successful scan" diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index 729923d..c21951f 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "net/url" + "os" + "os/signal" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -179,14 +181,28 @@ func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error { resultLogger := scan.NewResultLogger(logger) summarizer := scan.NewResultSummarizer(logger.Out) - for result := range s.Scan(u, cnf.Threads) { - resultLogger.Log(result) - summarizer.Add(result) + + osSigint := make(chan os.Signal, 1) + signal.Notify(osSigint, os.Interrupt) + + finishFunc := func() { + summarizer.Summarize() + logger.Info("Finished scan") } - summarizer.Summarize() + for result := range s.Scan(u, cnf.Threads) { + select { + case <-osSigint: + logger.Info("Received sigint, terminating...") + finishFunc() + return nil + default: + resultLogger.Log(result) + summarizer.Add(result) + } + } - logger.Info("Finished scan") + finishFunc() return nil } diff --git a/pkg/cmd/scan_integration_test.go b/pkg/cmd/scan_integration_test.go index c9fbbe7..fe36e6d 100644 --- a/pkg/cmd/scan_integration_test.go +++ b/pkg/cmd/scan_integration_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "sync" + "syscall" "testing" "time" @@ -30,7 +31,6 @@ func TestScanCommand(t *testing.T) { } w.WriteHeader(http.StatusNotFound) - }), ) defer testServer.Close() @@ -70,6 +70,68 @@ func TestScanCommand(t *testing.T) { assert.Equal(t, expectedRequests, requestsMap) } +func TestScanWithNoTargetShouldErr(t *testing.T) { + logger, _ := test.NewLogger() + + c, err := createCommand(logger) + assert.NoError(t, err) + assert.NotNil(t, c) + + testServer, serverAssertion := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }), + ) + defer testServer.Close() + + _, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt") + assert.Error(t, err) + + assert.Equal(t, 0, serverAssertion.Len()) +} + +func TestScanCommandCanBeInterrupted(t *testing.T) { + logger, loggerBuffer := test.NewLogger() + + c, err := createCommand(logger) + assert.NoError(t, err) + assert.NotNil(t, c) + + testServer, serverAssertion := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 650) + + if r.URL.Path == "/test/" { + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusNotFound) + }), + ) + defer testServer.Close() + + go func() { + time.Sleep(time.Millisecond * 200) + _ = syscall.Kill(syscall.Getpid(), syscall.SIGINT) + }() + + _, _, err = executeCommand( + c, + "scan", + testServer.URL, + "--dictionary", + "testdata/dict2.txt", + "-v", + "--http-timeout", + "900", + ) + assert.NoError(t, err) + + assert.True(t, serverAssertion.Len() > 0) + assert.Contains(t, loggerBuffer.String(), "Received sigint") +} + func TestScanWithRemoteDictionary(t *testing.T) { logger, _ := test.NewLogger() diff --git a/pkg/scan/producer/reproducer.go b/pkg/scan/producer/reproducer.go index 1ef03e7..8f31592 100644 --- a/pkg/scan/producer/reproducer.go +++ b/pkg/scan/producer/reproducer.go @@ -40,7 +40,7 @@ func (r *ReProducer) buildReproducer() func(result scan.Result) <-chan scan.Targ go func() { defer close(resultChannel) - if _, ok := statusCodesToSkip[result.Response.StatusCode]; ok { + if _, ok := statusCodesToSkip[result.StatusCode]; ok { return } diff --git a/pkg/scan/producer/reproducer_test.go b/pkg/scan/producer/reproducer_test.go index eb6e85c..13dbe30 100644 --- a/pkg/scan/producer/reproducer_test.go +++ b/pkg/scan/producer/reproducer_test.go @@ -22,20 +22,19 @@ func TestNewReProducer(t *testing.T) { sut := producer.NewReProducer(dictionaryProducer) - result := scan.Result{ - Target: scan.Target{ + result := scan.NewResult( + scan.Target{ Path: "/home", Method: http.MethodGet, Depth: 1, }, - Response: &http.Response{ + &http.Response{ StatusCode: http.StatusOK, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/contacts"), + URL: test.MustParseUrl(t, "http://mysite/contacts"), }, }, - } + ) reproducerFunc := sut.Reproduce() reproducerChannel := reproducerFunc(result) @@ -95,20 +94,19 @@ func TestReProducerShouldProduceNothingForDepthZero(t *testing.T) { sut := producer.NewReProducer(dictionaryProducer) - result := scan.Result{ - Target: scan.Target{ + result := scan.NewResult( + scan.Target{ Path: "/home", Method: http.MethodGet, Depth: 0, }, - Response: &http.Response{ + &http.Response{ StatusCode: http.StatusOK, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/contacts"), + URL: test.MustParseUrl(t, "http://mysite/contacts"), }, }, - } + ) reproducerFunc := sut.Reproduce() reproducerChannel := reproducerFunc(result) @@ -131,20 +129,19 @@ func TestReProducerShouldProduceNothingFor404Response(t *testing.T) { sut := producer.NewReProducer(dictionaryProducer) - result := scan.Result{ - Target: scan.Target{ + result := scan.NewResult( + scan.Target{ Path: "/home", Method: http.MethodGet, Depth: 3, }, - Response: &http.Response{ + &http.Response{ StatusCode: http.StatusNotFound, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/contacts"), + URL: test.MustParseUrl(t, "http://mysite/contacts"), }, }, - } + ) reproducerFunc := sut.Reproduce() reproducerChannel := reproducerFunc(result) diff --git a/pkg/scan/result_logger.go b/pkg/scan/result_logger.go index 58c5533..514f994 100644 --- a/pkg/scan/result_logger.go +++ b/pkg/scan/result_logger.go @@ -21,12 +21,12 @@ type ResultLogger struct { } func (c *ResultLogger) Log(result Result) { - statusCode := result.Response.StatusCode + statusCode := result.StatusCode l := c.logger.WithFields(logrus.Fields{ "status-code": statusCode, "method": result.Target.Method, - "url": result.Response.Request.URL, + "url": result.URL.String(), }) if statusCode == http.StatusNotFound { diff --git a/pkg/scan/result_logger_test.go b/pkg/scan/result_logger_test.go index 6cf263d..dec72c1 100644 --- a/pkg/scan/result_logger_test.go +++ b/pkg/scan/result_logger_test.go @@ -2,6 +2,7 @@ package scan_test import ( "net/http" + "net/url" "testing" "github.com/stefanoj3/dirstalk/pkg/common/test" @@ -73,12 +74,15 @@ func TestResultLogger(t *testing.T) { resultLogger := scan.NewResultLogger(logger) resultLogger.Log( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{}, + &http.Response{ StatusCode: tc.statusCode, - Request: &http.Request{}, + Request: &http.Request{ + URL: &url.URL{}, + }, }, - }, + ), ) assert.Contains(t, loggerBuffer.String(), tc.expectedMsg) diff --git a/pkg/scan/result_summarizer.go b/pkg/scan/result_summarizer.go index 8e8a0e2..b6ba6bb 100644 --- a/pkg/scan/result_summarizer.go +++ b/pkg/scan/result_summarizer.go @@ -28,7 +28,7 @@ func (s *ResultSummarizer) Add(result Result) { s.resultsReceived++ - if result.Response.StatusCode == http.StatusNotFound { + if result.StatusCode == http.StatusNotFound { return } @@ -39,6 +39,10 @@ func (s *ResultSummarizer) Summarize() { s.mux.Lock() defer s.mux.Unlock() + sort.Slice(s.results, func(i, j int) bool { + return s.results[i].Target.Path < s.results[j].Target.Path + }) + s.printSummary() s.printTree() @@ -47,9 +51,9 @@ func (s *ResultSummarizer) Summarize() { s.out, fmt.Sprintf( "%s [%d] [%s]", - r.Response.Request.URL, - r.Response.StatusCode, - r.Response.Request.Method, + r.URL.String(), + r.StatusCode, + r.Target.Method, ), ) } @@ -65,15 +69,11 @@ func (s *ResultSummarizer) printSummary() { func (s *ResultSummarizer) printTree() { root := gotree.New("/") - sort.Slice(s.results, func(i, j int) bool { - return s.results[i].Target.Path > s.results[j].Target.Path - }) - // TODO: improve efficiency for _, r := range s.results { currentBranch := root - parts := strings.Split(r.Response.Request.URL.Path, "/") + parts := strings.Split(r.URL.Path, "/") for _, p := range parts { if len(p) == 0 { continue diff --git a/pkg/scan/result_summarizer_test.go b/pkg/scan/result_summarizer_test.go index b27a56b..56fca1e 100644 --- a/pkg/scan/result_summarizer_test.go +++ b/pkg/scan/result_summarizer_test.go @@ -16,124 +16,147 @@ func TestResultSummarizer(t *testing.T) { summarizer := scan.NewResultSummarizer(b) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodPost, + Path: "/home", + }, + &http.Response{ StatusCode: 201, Request: &http.Request{ - Method: http.MethodPost, - URL: test.MustParseUrl(t, "http://mysite/home"), + URL: test.MustParseUrl(t, "http://mysite/home"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodPost, + Path: "/home/hidden", + }, + &http.Response{ StatusCode: 201, Request: &http.Request{ - Method: http.MethodPost, - URL: test.MustParseUrl(t, "http://mysite/home/hidden"), + URL: test.MustParseUrl(t, "http://mysite/home/hidden"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/home/about", + }, + &http.Response{ StatusCode: 200, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/home/about"), + URL: test.MustParseUrl(t, "http://mysite/home/about"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/home/about/me", + }, + &http.Response{ StatusCode: 200, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/home/about/me"), + URL: test.MustParseUrl(t, "http://mysite/home/about/me"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/home/home", + }, + &http.Response{ StatusCode: 200, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/home/home"), + URL: test.MustParseUrl(t, "http://mysite/home/home"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/contacts", + }, + &http.Response{ StatusCode: 200, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/contacts"), + URL: test.MustParseUrl(t, "http://mysite/contacts"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/gibberish", + }, + &http.Response{ StatusCode: 404, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/gibberish"), + URL: test.MustParseUrl(t, "http://mysite/gibberish"), }, }, - }, + ), ) summarizer.Add( - scan.Result{ - Response: &http.Response{ + scan.NewResult( + scan.Target{ + Method: http.MethodGet, + Path: "/path/to/my/files", + }, + &http.Response{ StatusCode: 200, Request: &http.Request{ - Method: http.MethodGet, - URL: test.MustParseUrl(t, "http://mysite/path/to/my/files"), + URL: test.MustParseUrl(t, "http://mysite/path/to/my/files"), }, }, - }, + ), ) summarizer.Summarize() expectedResult := `8 requests made, 7 results found / +├── contacts ├── home -│ ├── hidden │ ├── about │ │ └── me +│ ├── hidden │ └── home -├── contacts └── path └── to └── my └── files +http://mysite/contacts [200] [GET] http://mysite/home [201] [POST] -http://mysite/home/hidden [201] [POST] http://mysite/home/about [200] [GET] http://mysite/home/about/me [200] [GET] +http://mysite/home/hidden [201] [POST] http://mysite/home/home [200] [GET] -http://mysite/contacts [200] [GET] http://mysite/path/to/my/files [200] [GET] ` - assert.Equal(t, expectedResult, b.String()) } diff --git a/pkg/scan/scanner.go b/pkg/scan/scanner.go index b6e3652..76a0b7b 100644 --- a/pkg/scan/scanner.go +++ b/pkg/scan/scanner.go @@ -19,10 +19,23 @@ type Target struct { // Result represents the result of the scan of a single URL type Result struct { - Target Target + Target Target + + StatusCode int + URL url.URL + Response *http.Response } +// NewResult creates a new instance of the Result entity based on the Target and Response +func NewResult(target Target, response *http.Response) Result { + return Result{ + Target: target, + StatusCode: response.StatusCode, + URL: *response.Request.URL, + } +} + func NewScanner( httpClient Doer, producer Producer, @@ -105,11 +118,7 @@ func (s *Scanner) processTarget( l.WithError(err).Warn("failed to close response body") } - result := Result{ - Target: target, - Response: res, - } - + result := NewResult(target, res) results <- result for newTarget := range reproducer(result) {