Skip to content

Commit

Permalink
Introducing sigint handling
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanoj3 committed Jul 28, 2019
1 parent f57c5d1 commit fc31a05
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 94 deletions.
6 changes: 3 additions & 3 deletions functional-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ SCAN_RESULT=$(./dist/dirstalk scan -d resources/tests/dictionary.txt http://loca
assert_contains "$SCAN_RESULT" "/index" "result expected when performing scan"
assert_contains "$SCAN_RESULT" "/index/home" "result expected when performing scan"
assert_contains "$SCAN_RESULT" "8 requests made, 3 results found" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" "├── index" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" "└── home" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" "└── home" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" "├── home" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" "└── index" "a recap was expected when performing a scan"
assert_contains "$SCAN_RESULT" " └── home" "a recap was expected when performing a scan"

assert_not_contains "$SCAN_RESULT" "error" "no error is expected for a successful scan"

Expand Down
26 changes: 21 additions & 5 deletions pkg/cmd/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"os/signal"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -179,14 +181,28 @@ func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error {

resultLogger := scan.NewResultLogger(logger)
summarizer := scan.NewResultSummarizer(logger.Out)
for result := range s.Scan(u, cnf.Threads) {
resultLogger.Log(result)
summarizer.Add(result)

osSigint := make(chan os.Signal, 1)
signal.Notify(osSigint, os.Interrupt)

finishFunc := func() {
summarizer.Summarize()
logger.Info("Finished scan")
}

summarizer.Summarize()
for result := range s.Scan(u, cnf.Threads) {
select {
case <-osSigint:
logger.Info("Received sigint, terminating...")
finishFunc()
return nil
default:
resultLogger.Log(result)
summarizer.Add(result)
}
}

logger.Info("Finished scan")
finishFunc()

return nil
}
Expand Down
64 changes: 63 additions & 1 deletion pkg/cmd/scan_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"sync"
"syscall"
"testing"
"time"

Expand All @@ -30,7 +31,6 @@ func TestScanCommand(t *testing.T) {
}

w.WriteHeader(http.StatusNotFound)

}),
)
defer testServer.Close()
Expand Down Expand Up @@ -70,6 +70,68 @@ func TestScanCommand(t *testing.T) {
assert.Equal(t, expectedRequests, requestsMap)
}

func TestScanWithNoTargetShouldErr(t *testing.T) {
logger, _ := test.NewLogger()

c, err := createCommand(logger)
assert.NoError(t, err)
assert.NotNil(t, c)

testServer, serverAssertion := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}),
)
defer testServer.Close()

_, _, err = executeCommand(c, "scan", "--dictionary", "testdata/dict2.txt")
assert.Error(t, err)

assert.Equal(t, 0, serverAssertion.Len())
}

func TestScanCommandCanBeInterrupted(t *testing.T) {
logger, loggerBuffer := test.NewLogger()

c, err := createCommand(logger)
assert.NoError(t, err)
assert.NotNil(t, c)

testServer, serverAssertion := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 650)

if r.URL.Path == "/test/" {
w.WriteHeader(http.StatusOK)
return
}

w.WriteHeader(http.StatusNotFound)
}),
)
defer testServer.Close()

go func() {
time.Sleep(time.Millisecond * 200)
_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
}()

_, _, err = executeCommand(
c,
"scan",
testServer.URL,
"--dictionary",
"testdata/dict2.txt",
"-v",
"--http-timeout",
"900",
)
assert.NoError(t, err)

assert.True(t, serverAssertion.Len() > 0)
assert.Contains(t, loggerBuffer.String(), "Received sigint")
}

func TestScanWithRemoteDictionary(t *testing.T) {
logger, _ := test.NewLogger()

Expand Down
2 changes: 1 addition & 1 deletion pkg/scan/producer/reproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (r *ReProducer) buildReproducer() func(result scan.Result) <-chan scan.Targ
go func() {
defer close(resultChannel)

if _, ok := statusCodesToSkip[result.Response.StatusCode]; ok {
if _, ok := statusCodesToSkip[result.StatusCode]; ok {
return
}

Expand Down
33 changes: 15 additions & 18 deletions pkg/scan/producer/reproducer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,19 @@ func TestNewReProducer(t *testing.T) {

sut := producer.NewReProducer(dictionaryProducer)

result := scan.Result{
Target: scan.Target{
result := scan.NewResult(
scan.Target{
Path: "/home",
Method: http.MethodGet,
Depth: 1,
},
Response: &http.Response{
&http.Response{
StatusCode: http.StatusOK,
Request: &http.Request{
Method: http.MethodGet,
URL: test.MustParseUrl(t, "http://mysite/contacts"),
URL: test.MustParseUrl(t, "http://mysite/contacts"),
},
},
}
)

reproducerFunc := sut.Reproduce()
reproducerChannel := reproducerFunc(result)
Expand Down Expand Up @@ -95,20 +94,19 @@ func TestReProducerShouldProduceNothingForDepthZero(t *testing.T) {

sut := producer.NewReProducer(dictionaryProducer)

result := scan.Result{
Target: scan.Target{
result := scan.NewResult(
scan.Target{
Path: "/home",
Method: http.MethodGet,
Depth: 0,
},
Response: &http.Response{
&http.Response{
StatusCode: http.StatusOK,
Request: &http.Request{
Method: http.MethodGet,
URL: test.MustParseUrl(t, "http://mysite/contacts"),
URL: test.MustParseUrl(t, "http://mysite/contacts"),
},
},
}
)

reproducerFunc := sut.Reproduce()
reproducerChannel := reproducerFunc(result)
Expand All @@ -131,20 +129,19 @@ func TestReProducerShouldProduceNothingFor404Response(t *testing.T) {

sut := producer.NewReProducer(dictionaryProducer)

result := scan.Result{
Target: scan.Target{
result := scan.NewResult(
scan.Target{
Path: "/home",
Method: http.MethodGet,
Depth: 3,
},
Response: &http.Response{
&http.Response{
StatusCode: http.StatusNotFound,
Request: &http.Request{
Method: http.MethodGet,
URL: test.MustParseUrl(t, "http://mysite/contacts"),
URL: test.MustParseUrl(t, "http://mysite/contacts"),
},
},
}
)

reproducerFunc := sut.Reproduce()
reproducerChannel := reproducerFunc(result)
Expand Down
4 changes: 2 additions & 2 deletions pkg/scan/result_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ type ResultLogger struct {
}

func (c *ResultLogger) Log(result Result) {
statusCode := result.Response.StatusCode
statusCode := result.StatusCode

l := c.logger.WithFields(logrus.Fields{
"status-code": statusCode,
"method": result.Target.Method,
"url": result.Response.Request.URL,
"url": result.URL.String(),
})

if statusCode == http.StatusNotFound {
Expand Down
12 changes: 8 additions & 4 deletions pkg/scan/result_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scan_test

import (
"net/http"
"net/url"
"testing"

"github.com/stefanoj3/dirstalk/pkg/common/test"
Expand Down Expand Up @@ -73,12 +74,15 @@ func TestResultLogger(t *testing.T) {

resultLogger := scan.NewResultLogger(logger)
resultLogger.Log(
scan.Result{
Response: &http.Response{
scan.NewResult(
scan.Target{},
&http.Response{
StatusCode: tc.statusCode,
Request: &http.Request{},
Request: &http.Request{
URL: &url.URL{},
},
},
},
),
)

assert.Contains(t, loggerBuffer.String(), tc.expectedMsg)
Expand Down
18 changes: 9 additions & 9 deletions pkg/scan/result_summarizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (s *ResultSummarizer) Add(result Result) {

s.resultsReceived++

if result.Response.StatusCode == http.StatusNotFound {
if result.StatusCode == http.StatusNotFound {
return
}

Expand All @@ -39,6 +39,10 @@ func (s *ResultSummarizer) Summarize() {
s.mux.Lock()
defer s.mux.Unlock()

sort.Slice(s.results, func(i, j int) bool {
return s.results[i].Target.Path < s.results[j].Target.Path
})

s.printSummary()
s.printTree()

Expand All @@ -47,9 +51,9 @@ func (s *ResultSummarizer) Summarize() {
s.out,
fmt.Sprintf(
"%s [%d] [%s]",
r.Response.Request.URL,
r.Response.StatusCode,
r.Response.Request.Method,
r.URL.String(),
r.StatusCode,
r.Target.Method,
),
)
}
Expand All @@ -65,15 +69,11 @@ func (s *ResultSummarizer) printSummary() {
func (s *ResultSummarizer) printTree() {
root := gotree.New("/")

sort.Slice(s.results, func(i, j int) bool {
return s.results[i].Target.Path > s.results[j].Target.Path
})

// TODO: improve efficiency
for _, r := range s.results {
currentBranch := root

parts := strings.Split(r.Response.Request.URL.Path, "/")
parts := strings.Split(r.URL.Path, "/")
for _, p := range parts {
if len(p) == 0 {
continue
Expand Down
Loading

0 comments on commit fc31a05

Please sign in to comment.