Skip to content

Commit

Permalink
refactor and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonhadfield committed Jul 7, 2024
1 parent b2b0906 commit 1257965
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 54 deletions.
1 change: 0 additions & 1 deletion providers/ipqs/ipqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ func ToPtr[T any](v T) *T {

func TestRateHost(t *testing.T) {
rc := providers.RatingConfig{}
// prc := providers.ProviderRatingConfig{}
rc.Global.HighThreatCountryCodes = []string{"CN"}
rc.Global.MediumThreatCountryCodes = []string{"US"}
rc.ProviderRatingsConfigs.IPQS.ProxyScore = ToPtr(float64(3))
Expand Down
4 changes: 2 additions & 2 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ type RatingConfig struct {
MediumThreatCountryMatchScore float64 `json:"mediumThreatCountryMatchScore,omitempty"`
} `json:"shodan"`
VirusTotal struct {
Suspicious float64 `json:"suspicious,omitempty"`
Malicious float64 `json:"malicious,omitempty"`
SuspiciousScore *float64 `json:"suspiciousScore,omitempty"`
MaliciousScore *float64 `json:"maliciousScore,omitempty"`
} `json:"virustotal"`
} `json:"providers"`
}
Expand Down
179 changes: 128 additions & 51 deletions providers/virustotal/virustotal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@ import (
)

const (
ProviderName = "virustotal"
APIURL = "https://www.virustotal.com"
HostIPPath = "/api/v3/ip_addresses"
IndentPipeHyphens = " |-----"
portLastModifiedFormat = "2006-01-02T15:04:05.999999"
ResultTTL = 12 * time.Hour
dataColumnNo = 2
veryHighScore = 10
APITimeout = 10 * time.Second
ProviderName = "virustotal"
APIURL = "https://www.virustotal.com"
HostIPPath = "/api/v3/ip_addresses"
IndentPipeHyphens = " |-----"
ResultTTL = 12 * time.Hour
dataColumnNo = 2
veryHighScore = 10
defaultHarmlessScore = 0
defaultSuspiciousScore = 7
defaultMaliciousScore = 10
defaultMediumThreatCountryMatchScore = 6.0
defaultHighThreatCountryMatchScore = 9.0
APITimeout = 10 * time.Second
)

type Config struct {
Expand All @@ -50,43 +54,110 @@ type ProviderClient struct {
session.Session
}

func (c *ProviderClient) RateHostData(findRes []byte, ratingConfigJSON []byte) (providers.RateResult, error) {
var doc HostSearchResult
// chooseScore returns the user defined score if provided and higher than the running total
func chooseScore(def, runningTotal float64, user *float64) float64 {
if user != nil {
if *user > runningTotal {
return *user
}

var rateResult providers.RateResult
return runningTotal
}

// return default if no user defined score
return def
}

func loadRatingConfig(ratingConfigJSON []byte) (providers.RatingConfig, error) {
var ratingConfig providers.RatingConfig

if err := json.Unmarshal(findRes, &doc); err != nil {
return providers.RateResult{}, fmt.Errorf("failed to unmarshall virustotal data: %w", err)
if err := json.Unmarshal(ratingConfigJSON, &ratingConfig); err != nil {
return providers.RatingConfig{}, fmt.Errorf("error unmarshalling rating config: %w", err)
}

// assume no result if no host id
if doc.Data.ID == "" {
return providers.RateResult{}, nil
return ratingConfig, nil
}

func loadFindHostResults(in []byte) (HostSearchResult, error) {
var doc HostSearchResult

if err := json.Unmarshal(in, &doc); err != nil {
return HostSearchResult{}, fmt.Errorf("error unmarshalling find result: %w", err)
}

return doc, nil
}

func countryCodeInCodes(countryCode string, codes []string) bool {
for _, c := range codes {
if strings.EqualFold(countryCode, c) {
return true
}
}

return false
}

func rateHost(attrs HostSearchResultDataAttributes, ratingConfig providers.RatingConfig) providers.RateResult {
var rateResult providers.RateResult

// cannot reach here unless detected
rateResult.Detected = true

if attrs.Country != "" {
if countryCodeInCodes(attrs.Country, ratingConfig.Global.MediumThreatCountryCodes) {
// if user provided score, then use that, otherwise use default
rateResult.Score = chooseScore(defaultMediumThreatCountryMatchScore, rateResult.Score, ratingConfig.ProviderRatingsConfigs.IPQS.MediumThreatCountryMatchScore)
}

if countryCodeInCodes(attrs.Country, ratingConfig.Global.HighThreatCountryCodes) {
// if user provided score, then use that, otherwise use default
rateResult.Score = chooseScore(defaultHighThreatCountryMatchScore, rateResult.Score, ratingConfig.ProviderRatingsConfigs.IPQS.HighThreatCountryMatchScore)
}
}

switch {
case doc.Data.Attributes.LastAnalysisStats.Malicious > 0:
rateResult.Score += veryHighScore
case attrs.LastAnalysisStats.Malicious > 0:
rateResult.Score = chooseScore(defaultMaliciousScore,
rateResult.Score,
ratingConfig.ProviderRatingsConfigs.VirusTotal.MaliciousScore)
rateResult.Threat = "very high"
rateResult.Reasons = append(rateResult.Reasons, "malicious")
case doc.Data.Attributes.LastAnalysisStats.Suspicious > 0:
case attrs.LastAnalysisStats.Suspicious > 0:
rateResult.Threat = "high"
rateResult.Score += 7
rateResult.Score = chooseScore(defaultSuspiciousScore,
rateResult.Score,
ratingConfig.ProviderRatingsConfigs.VirusTotal.SuspiciousScore)
rateResult.Reasons = append(rateResult.Reasons, "suspicious")
case doc.Data.Attributes.LastAnalysisStats.Harmless > 0 || doc.Data.Attributes.LastAnalysisStats.Undetected > 0:
case attrs.LastAnalysisStats.Harmless > 0 || attrs.LastAnalysisStats.Undetected > 0:
rateResult.Threat = "low"
rateResult.Score += 3
rateResult.Score = defaultHarmlessScore
rateResult.Reasons = append(rateResult.Reasons, "harmless")
}

if rateResult.Score > veryHighScore {
rateResult.Score = veryHighScore
}

return rateResult, nil
return rateResult
}

func (c *ProviderClient) RateHostData(findRes []byte, ratingConfigJSON []byte) (providers.RateResult, error) {
hostData, err := loadFindHostResults(findRes)
if err != nil {
return providers.RateResult{}, fmt.Errorf("error loading find host results: %w", err)
}

ratingConfig, err := loadRatingConfig(ratingConfigJSON)
if err != nil {
return providers.RateResult{}, fmt.Errorf("error loading rating config: %w", err)
}

if hostData.Data.ID == "" {
return providers.RateResult{}, fmt.Errorf("no host id found")
}

return rateHost(hostData.Data.Attributes, ratingConfig), nil
}

func (c *ProviderClient) Enabled() bool {
Expand Down Expand Up @@ -855,39 +926,45 @@ type LastAnalysisResults struct {
} `json:"zvelo,omitempty"`
}

type LastAnalysisStats struct {
Malicious int `json:"malicious,omitempty"`
Suspicious int `json:"suspicious,omitempty"`
Undetected int `json:"undetected,omitempty"`
Harmless int `json:"harmless,omitempty"`
Timeout int `json:"timeout,omitempty"`
}

type TotalVotes struct {
Harmless int `json:"harmless,omitempty"`
Malicious int `json:"malicious,omitempty"`
}

type HostSearchResultDataAttributes struct {
LastAnalysisStats LastAnalysisStats `json:"last_analysis_stats,omitempty"`
LastAnalysisResults LastAnalysisResults `json:"last_analysis_results,omitempty"`
LastModificationDate int `json:"last_modification_date,omitempty"`
LastAnalysisDate int `json:"last_analysis_date,omitempty"`

Whois string `json:"whois,omitempty"`
WhoisDate int `json:"whois_date,omitempty"`
Reputation int `json:"reputation,omitempty"`
Country string `json:"country,omitempty"`
TotalVotes TotalVotes `json:"total_votes,omitempty"`
Continent string `json:"continent,omitempty"`
Asn int `json:"asn,omitempty"`
AsOwner string `json:"as_owner,omitempty"`
Network string `json:"network,omitempty"`
Tags []any `json:"tags,omitempty"`
RegionalInternetRegistry string `json:"regional_internet_registry,omitempty"`
}

type HostSearchResultData struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Links struct {
Self string `json:"self,omitempty"`
} `json:"links,omitempty"`
Attributes struct {
LastAnalysisStats struct {
Malicious int `json:"malicious,omitempty"`
Suspicious int `json:"suspicious,omitempty"`
Undetected int `json:"undetected,omitempty"`
Harmless int `json:"harmless,omitempty"`
Timeout int `json:"timeout,omitempty"`
} `json:"last_analysis_stats,omitempty"`
LastAnalysisResults LastAnalysisResults `json:"last_analysis_results,omitempty"`
LastModificationDate int `json:"last_modification_date,omitempty"`
LastAnalysisDate int `json:"last_analysis_date,omitempty"`

Whois string `json:"whois,omitempty"`
WhoisDate int `json:"whois_date,omitempty"`
Reputation int `json:"reputation,omitempty"`
Country string `json:"country,omitempty"`
TotalVotes struct {
Harmless int `json:"harmless,omitempty"`
Malicious int `json:"malicious,omitempty"`
} `json:"total_votes,omitempty"`
Continent string `json:"continent,omitempty"`
Asn int `json:"asn,omitempty"`
AsOwner string `json:"as_owner,omitempty"`
Network string `json:"network,omitempty"`
Tags []any `json:"tags,omitempty"`
RegionalInternetRegistry string `json:"regional_internet_registry,omitempty"`
} `json:"attributes,omitempty"`
Attributes HostSearchResultDataAttributes `json:"attributes,omitempty"`
}

type AnalysisResultData struct {
Expand Down
38 changes: 38 additions & 0 deletions providers/virustotal/virustotal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,49 @@ import (
"os"
"testing"

"github.com/jonhadfield/ipscout/providers"
"github.com/jonhadfield/ipscout/session"

"github.com/stretchr/testify/require"
)

func TestRateHost(t *testing.T) {
rc := providers.RatingConfig{}
rc.Global.HighThreatCountryCodes = []string{"CN"}
rc.Global.MediumThreatCountryCodes = []string{"US"}
rc.ProviderRatingsConfigs.VirusTotal.SuspiciousScore = ToPtr(float64(7.6))
rc.ProviderRatingsConfigs.VirusTotal.MaliciousScore = ToPtr(float64(9.2))
attrs := HostSearchResultDataAttributes{
LastAnalysisStats: LastAnalysisStats{},
LastAnalysisResults: LastAnalysisResults{},
LastModificationDate: 0,
LastAnalysisDate: 0,
Reputation: 0,
Country: "",
TotalVotes: TotalVotes{},
Asn: 0,
}

res := rateHost(attrs, rc)
// nothing detected so should return 0
require.Equal(t, float64(0), res.Score)

// expect country US to bring score up to 3
attrs.Country = "US"
res = rateHost(attrs, rc)
require.Equal(t, float64(6), res.Score)

// setting a report of suspicious should bring score up to 7.6
attrs.LastAnalysisStats.Suspicious = 2
res = rateHost(attrs, rc)
require.Equal(t, float64(7.6), res.Score)

// setting a report of malicious should bring score up to 9.2
attrs.LastAnalysisStats.Malicious = 1
res = rateHost(attrs, rc)
require.Equal(t, float64(9.2), res.Score)
}

//nolint:funlen
func TestVirusTotalHostQuery(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit 1257965

Please sign in to comment.