Skip to content

Commit

Permalink
Skip resource refresh from cloudtrail if full refresh is already sche…
Browse files Browse the repository at this point in the history
…duled
  • Loading branch information
ramanan-ravi committed Jul 4, 2024
1 parent 402d361 commit 1580060
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *CloudResourceChangesAWS) Initialize() error {
return ErrNoCloudTrailsFound
}
c.cloudTrailTrails = trails
log.Info().Msgf("Following CloudTrail Trails are monitored for events every hour to update the cloud resources in the management console")
log.Info().Msgf("Following CloudTrail Trails are monitored for events every 30 minutes to update the cloud resources in the management console")
for i, trail := range c.cloudTrailTrails {
log.Info().Msgf("%d. %s (Region: %s)", i+1, trail.Arn, trail.Region)
}
Expand Down
11 changes: 9 additions & 2 deletions cloud_resource_changes/cloud_resource_changes_aws/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ func GetSupportedAwsRegions() []string {

func getCloudTrailTrails(config util.Config) []CloudTrailTrail {
var query string
var isOrganizationTrail string
if config.IsOrganizationDeployment {
isOrganizationTrail = "and is_organization_trail = true"
}
if len(config.CloudAuditLogsIDs) == 0 {
query = "steampipe query --output json \"select * from aws_" + config.AccountID + ".aws_cloudtrail_trail where is_organization_trail = true and is_multi_region_trail = true\""
query = "steampipe query --output json \"select * from aws_" + config.AccountID + ".aws_cloudtrail_trail where is_multi_region_trail = true " + isOrganizationTrail + "\""
} else {
query = "steampipe query --output json \"select * from aws_all.aws_cloudtrail_trail where is_organization_trail = true and is_multi_region_trail = true and arn in ('" + strings.Join(config.CloudAuditLogsIDs, "', '") + "')\""
query = "steampipe query --output json \"select * from aws_all.aws_cloudtrail_trail where is_multi_region_trail = true " + isOrganizationTrail + " and arn in ('" + strings.Join(config.CloudAuditLogsIDs, "', '") + "')\""
}
cmd := exec.Command("bash", "-c", query)
stdOut, stdErr := cmd.CombinedOutput()
Expand All @@ -48,6 +52,9 @@ func getCloudTrailTrails(config util.Config) []CloudTrailTrail {
selectedTrailList = append(selectedTrailList, trail)
selectedARNs[trail.Arn] = true
}
if len(selectedTrailList) == 0 {
log.Error().Msg("cloudtrail trail arn provided does not exist or is not a multi-region trail")
}
return selectedTrailList
}

Expand Down
8 changes: 4 additions & 4 deletions output/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ func WriteScanStatus(status, scanID, scanMessage string) {
return
}

log.Debug().Msgf("Writing status: %s", status)
log.Debug().Msgf("Writing scan status: %s", status)
err = writeToFile(byteJSON, scanStatusFilename)
if err != nil {
log.Error().Msgf("Error writing status data to %s, Error: %s", scanStatusFilename, err)
log.Error().Msgf("Error writing scan status data to %s, Error: %s", scanStatusFilename, err)
return
}
}
Expand All @@ -50,10 +50,10 @@ func WriteCloudResourceRefreshStatus(nodeID, refreshStatus, refreshMessage strin
return
}

log.Debug().Msgf("Writing status: %s, %s", refreshStatus, refreshMessage)
log.Debug().Msgf("Writing refresh status: %s, %s", refreshStatus, refreshMessage)
err = writeToFile(byteJSON, cloudResourceRefreshStatusFilename)
if err != nil {
log.Error().Msgf("Error writing status data to %s, Error: %s", cloudResourceRefreshStatusFilename, err)
log.Error().Msgf("Error writing refresh status data to %s, Error: %s", cloudResourceRefreshStatusFilename, err)
return
}
}
Expand Down
23 changes: 11 additions & 12 deletions query_resource/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/deepfence/ThreatMapper/deepfence_utils/log"
"github.com/deepfence/ThreatMapper/deepfence_utils/utils"
"github.com/deepfence/cloud-scanner/output"
"github.com/deepfence/cloud-scanner/util"
_ "github.com/lib/pq"
)
Expand Down Expand Up @@ -75,7 +74,7 @@ func clearPostgresqlCache() error {
return nil
}

func QueryAndRegisterResources(config util.Config, accountsToRefresh []util.AccountsToRefresh, completeRefresh bool) []error {
func (r *ResourceRefreshService) QueryAndRegisterResources(accountsToRefresh []util.AccountsToRefresh, completeRefresh bool) []error {
if completeRefresh {
err := clearPostgresqlCache()
if err != nil {
Expand All @@ -92,29 +91,29 @@ func QueryAndRegisterResources(config util.Config, accountsToRefresh []util.Acco
defer cloudResourcesFile.Close()

for _, account := range accountsToRefresh {
output.WriteCloudResourceRefreshStatus(account.NodeID, utils.ScanStatusStarting, "")
r.SetResourceRefreshStatus(account, utils.ScanStatusStarting)
}

count := 0
var errs = make([]error, 0)
for _, account := range accountsToRefresh {
log.Debug().Msgf("Started querying resources for %v", account)
output.WriteCloudResourceRefreshStatus(account.NodeID, utils.ScanStatusInProgress, "")
r.SetResourceRefreshStatus(account, utils.ScanStatusInProgress)

for _, cloudResourceInfo := range cloudProviderToResourceMap[config.CloudProvider] {
for _, cloudResourceInfo := range cloudProviderToResourceMap[r.config.CloudProvider] {
// If ResourceTypes is empty, refresh all resource types. Otherwise, only specified ones
if len(account.ResourceTypes) > 0 {
if !util.InSlice(cloudResourceInfo.Table, account.ResourceTypes) {
continue
}
err = clearPostgresqlCacheRows(config.CloudProvider + "_" + account.AccountID + "." + cloudResourceInfo.Table)
err = clearPostgresqlCacheRows(r.config.CloudProvider + "_" + account.AccountID + "." + cloudResourceInfo.Table)
if err != nil {
errs = append(errs, err)
continue
}
}

ingestedCount, err := queryResources(account.AccountID, cloudResourceInfo, config, cloudResourcesFile)
ingestedCount, err := r.queryResources(account.AccountID, cloudResourceInfo, cloudResourcesFile)
if err != nil {
errs = append(errs, err)
}
Expand All @@ -124,7 +123,7 @@ func QueryAndRegisterResources(config util.Config, accountsToRefresh []util.Acco
}

log.Debug().Msgf("Querying resources complete for %v", account)
output.WriteCloudResourceRefreshStatus(account.NodeID, utils.ScanStatusSuccess, "")
r.SetResourceRefreshStatus(account, utils.ScanStatusSuccess)
}
log.Info().Msgf("Cloud resources ingested: %d", count)
return errs
Expand All @@ -141,10 +140,10 @@ func clearPostgresqlCacheRows(keyPrefix string) error {
return nil
}

func queryResources(accountId string, cloudResourceInfo CloudResourceInfo, config util.Config, cloudResourcesFile *os.File) (int, error) {
func (r *ResourceRefreshService) queryResources(accountId string, cloudResourceInfo CloudResourceInfo, cloudResourcesFile *os.File) (int, error) {
log.Debug().Msgf("Querying resources for %s", cloudResourceInfo.Table)

query := "steampipe query --output json \"select \\\"" + strings.Join(cloudResourceInfo.Columns[:], "\\\" , \\\"") + "\\\" from " + config.CloudProvider + "_" + strings.Replace(accountId, "-", "", -1) + "." + cloudResourceInfo.Table + " \""
query := "steampipe query --output json \"select \\\"" + strings.Join(cloudResourceInfo.Columns[:], "\\\" , \\\"") + "\\\" from " + r.config.CloudProvider + "_" + strings.Replace(accountId, "-", "", -1) + "." + cloudResourceInfo.Table + " \""
var stdOut []byte
var stdErr error
for i := 0; i <= 3; i++ {
Expand Down Expand Up @@ -176,8 +175,8 @@ func queryResources(accountId string, cloudResourceInfo CloudResourceInfo, confi

var private_dns_name string
for _, obj := range objMap {
obj["account_id"] = util.GetNodeID(config.CloudProvider, accountId)
obj["cloud_provider"] = config.CloudProvider
obj["account_id"] = util.GetNodeID(r.config.CloudProvider, accountId)
obj["cloud_provider"] = r.config.CloudProvider
if _, ok := obj["title"]; ok {
obj["name"] = fmt.Sprint(obj["title"])
delete(obj, "title")
Expand Down
157 changes: 157 additions & 0 deletions query_resource/query_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package query_resource

import (
"sync"
"sync/atomic"
"time"

"github.com/deepfence/ThreatMapper/deepfence_utils/log"
"github.com/deepfence/ThreatMapper/deepfence_utils/utils"
"github.com/deepfence/cloud-scanner/cloud_resource_changes"
"github.com/deepfence/cloud-scanner/output"
"github.com/deepfence/cloud-scanner/util"
)

type ResourceRefreshService struct {
config util.Config
resourceRefreshCount atomic.Int32
resourceRefreshStatus sync.Map
CloudResourceChanges cloud_resource_changes.CloudResourceChanges
mutex sync.Mutex
}

func NewResourceRefreshService(config util.Config) (*ResourceRefreshService, error) {
cloudResourceChanges, err := cloud_resource_changes.NewCloudResourceChanges(config)
if err != nil {
return nil, err
}

return &ResourceRefreshService{
config: config,
resourceRefreshCount: atomic.Int32{},
resourceRefreshStatus: sync.Map{},
CloudResourceChanges: cloudResourceChanges,
}, nil
}

func (r *ResourceRefreshService) Initialize() {
log.Info().Msgf("CloudResourceChanges Initialization started")
err := r.CloudResourceChanges.Initialize()
if err != nil {
log.Warn().Msgf("%+v", err)
}
log.Info().Msgf("CloudResourceChanges Initialization completed")

go r.refreshResourcesFromTrailPeriodically()
}

func (r *ResourceRefreshService) Lock() {
r.resourceRefreshCount.Add(1)
log.Debug().Msgf("Resource refresh count: %d", r.resourceRefreshCount.Load())
r.mutex.Lock()
}

func (r *ResourceRefreshService) Unlock() {
r.resourceRefreshCount.Add(-1)
log.Debug().Msgf("Resource refresh count: %d", r.resourceRefreshCount.Load())
r.mutex.Unlock()
}

func (r *ResourceRefreshService) SetResourceRefreshStatus(account util.AccountsToRefresh, refreshStatus string) {
r.resourceRefreshStatus.Store(account.AccountID, refreshStatus)
output.WriteCloudResourceRefreshStatus(account.NodeID, refreshStatus, "")
}

// SkipCloudAuditLogUpdate Weather to skip cloud audit log based resource updates
func (r *ResourceRefreshService) SkipCloudAuditLogUpdate(accountID string) bool {
var refreshStatus any
var ok bool
if refreshStatus, ok = r.resourceRefreshStatus.Load(accountID); !ok {
// Skip the resources update
return true
}
refreshStatusString := refreshStatus.(string)
if refreshStatusString == utils.ScanStatusSuccess || refreshStatusString == utils.ScanStatusFailed {
// Proceed with the resources update
return false
}
// Skip the resources update
return true
}

func (r *ResourceRefreshService) refreshResourcesFromTrailPeriodically() {
refreshTicker := time.NewTicker(2 * time.Minute) // temporarily set to 2 min for testing
for {
select {
case <-refreshTicker.C:
go func() {
r.refreshResourcesFromTrail()
}()
}
}
}

func (r *ResourceRefreshService) refreshResourcesFromTrail() {
log.Info().Msg("Started updating cloud resources")
cloudResourceTypesToRefresh, _ := r.CloudResourceChanges.GetResourceTypesToRefresh()
if len(cloudResourceTypesToRefresh) == 0 {
return
}
var accountsToRefresh []util.AccountsToRefresh
for accountID, resourceTypes := range cloudResourceTypesToRefresh {
if r.SkipCloudAuditLogUpdate(accountID) {
log.Debug().Msgf("Skipping resource refresh updation for account %s, account wide refresh already scheduled", accountID)
continue
}

log.Debug().Msgf("Resource refresh updation for account %s, resource types: %v", accountID, resourceTypes)
accountsToRefresh = append(accountsToRefresh, util.AccountsToRefresh{
AccountID: accountID,
NodeID: util.GetNodeID(r.config.CloudProvider, accountID),
ResourceTypes: resourceTypes,
})
}

r.FetchCloudAccountResources(accountsToRefresh, false)
log.Info().Msg("Updating cloud resources complete")
}

func (r *ResourceRefreshService) GetRefreshCount() int32 {
return r.resourceRefreshCount.Load()
}

// FetchCloudResources Fetch cloud resources from all accounts
func (r *ResourceRefreshService) FetchCloudResources(organizationAccounts []util.MonitoredAccount) {
log.Info().Msg("Querying cloud resources")

var accountsToRefresh []util.AccountsToRefresh
if r.config.IsOrganizationDeployment {
for _, monitoredAccount := range organizationAccounts {
accountsToRefresh = append(accountsToRefresh, util.AccountsToRefresh{
AccountID: monitoredAccount.AccountID,
NodeID: monitoredAccount.NodeID,
})
}
} else {
accountsToRefresh = []util.AccountsToRefresh{
{
AccountID: r.config.AccountID,
NodeID: r.config.NodeID,
},
}
}
r.FetchCloudAccountResources(accountsToRefresh, true)
log.Info().Msg("Querying cloud resources complete")
}

// FetchCloudAccountResources Fetch cloud resources from selected accounts and resource types
func (r *ResourceRefreshService) FetchCloudAccountResources(accountsToRefresh []util.AccountsToRefresh, completeRefresh bool) {
// Only one cloud account's resources are refreshed at a time
r.Lock()
defer r.Unlock()

errorsCollected := r.QueryAndRegisterResources(accountsToRefresh, completeRefresh)
if len(errorsCollected) > 0 {
log.Error().Msgf("Error in sending resources, errors: %+v", errorsCollected)
}
}
67 changes: 0 additions & 67 deletions service/query_service.go

This file was deleted.

Loading

0 comments on commit 1580060

Please sign in to comment.