diff --git a/examples/sample-rules2.json b/examples/sample-rules2.json new file mode 100644 index 0000000..9621f8c --- /dev/null +++ b/examples/sample-rules2.json @@ -0,0 +1,16 @@ +[ + { + "sha256": "d84db96af8c2e60ac4c851a21ec460f6f84e0235beb17d24a78712b9b021ed57", + "type": "CERTIFICATE", + "policy": "ALLOWLIST", + "custom_msg": "", + "description": "Software Signing by Apple Inc." + }, + { + "sha256": "345a8e098bd04794aaeefda8c9ef56a0bf3d3706d67d35bc0e23f11bb3bffce5", + "type": "CERTIFICATE", + "policy": "ALLOWLIST", + "custom_msg": "", + "description": "Developer ID Application: Google, Inc. (EQHXZ8M8AV)" + } +] diff --git a/internal/cli/rules/rules-export.go b/internal/cli/rules/rules-export.go index 0d8d45a..3389bf5 100644 --- a/internal/cli/rules/rules-export.go +++ b/internal/cli/rules/rules-export.go @@ -1,16 +1,22 @@ package rules import ( + "encoding/json" + "fmt" + "io/ioutil" + "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/airbnb/rudolph/internal/csv" "github.com/airbnb/rudolph/pkg/dynamodb" "github.com/airbnb/rudolph/pkg/model/globalrules" + "github.com/airbnb/rudolph/pkg/types" ) func addRuleExportCommand() { var filename string + var format string var ruleExportCmd = &cobra.Command{ Use: "export ", Aliases: []string{"rules-export"}, @@ -23,19 +29,74 @@ func addRuleExportCommand() { dynamodbClient := dynamodb.GetClient(table, region) - return runExport(dynamodbClient, filename) + return runExport(dynamodbClient, filename, format) }, } ruleExportCmd.Flags().StringVarP(&filename, "filename", "f", "", "The filename") _ = ruleExportCmd.MarkFlagRequired("filename") + ruleExportCmd.Flags().StringVarP(&format, "fileformat", "t", "csv", "File format (one of: [json|csv])") + RulesCmd.AddCommand(ruleExportCmd) } func runExport( client dynamodb.QueryAPI, filename string, + format string, +) (err error) { + switch format { + case "json": + return runJsonExport(client, filename) + case "csv": + return runCsvExport(client, filename) + } + return +} + +type fileRule struct { + RuleType types.RuleType `json:"type"` + Policy types.Policy `json:"policy"` + SHA256 string `json:"sha256"` + CustomMessage string `json:"custom_msg,omitempty"` + Description string `json:"description"` +} + +func runJsonExport(client dynamodb.QueryAPI, filename string) (err error) { + var jsonRules []fileRule + fmt.Println("Querying rules from DynamoDB...") + total, err := getRules(client, func(rule globalrules.GlobalRuleRow) (err error) { + jsonRules = append(jsonRules, fileRule{ + SHA256: rule.SHA256, + RuleType: rule.RuleType, + Policy: rule.Policy, + CustomMessage: rule.CustomMessage, + Description: rule.Description, + }) + return + }) + if err != nil { + return + } + + jsondata, err := json.MarshalIndent(jsonRules, "", " ") + if err != nil { + return + } + err = ioutil.WriteFile(filename, jsondata, 0644) + if err != nil { + return + } + + fmt.Printf("rules discovered: %d, rules written: %d\n", total, len(jsonRules)) + + return +} + +func runCsvExport( + client dynamodb.QueryAPI, + filename string, ) (err error) { csvRules := make(chan []string) @@ -52,6 +113,45 @@ func runExport( panic(err) } + fmt.Println("Querying rules from DynamoDB...") + var totalWritten int64 + total, err := getRules(client, func(rule globalrules.GlobalRuleRow) (err error) { + ruleType, err := rule.RuleType.MarshalText() + if err != nil { + return + } + policy, err := rule.Policy.MarshalText() + if err != nil { + return + } + record := []string{ + rule.SHA256, + string(ruleType), + string(policy), + rule.CustomMessage, + rule.Description, + } + if err != nil { + return + } + + totalWritten += 1 + csvRules <- record + return + }) + if err != nil { + return + } + + close(csvRules) + wg.Wait() + + fmt.Printf("rules discovered: %d, rules written: %d\n", total, totalWritten) + + return +} + +func getRules(client dynamodb.QueryAPI, callback func(globalrules.GlobalRuleRow) error) (total int64, err error) { var key *dynamodb.PrimaryKey for { rules, nextkey, inerr := globalrules.GetPaginatedGlobalRules(client, 50, key) @@ -64,17 +164,11 @@ func runExport( } for _, rule := range *rules { - ruleType, _ := rule.RuleType.MarshalText() - policy, _ := rule.Policy.MarshalText() - record := []string{ - rule.SHA256, - string(ruleType), - string(policy), - rule.CustomMessage, - rule.Description, + total += 1 + err = callback(rule) + if err != nil { + return } - - csvRules <- record } if nextkey == nil { @@ -82,9 +176,5 @@ func runExport( } key = nextkey } - close(csvRules) - - wg.Wait() - - return nil + return } diff --git a/internal/cli/rules/rules-import.go b/internal/cli/rules/rules-import.go index 80dad22..dcd77aa 100644 --- a/internal/cli/rules/rules-import.go +++ b/internal/cli/rules/rules-import.go @@ -1,7 +1,12 @@ package rules import ( + "encoding/json" + "errors" "fmt" + "io/ioutil" + "os" + "strings" "sync" "sync/atomic" @@ -61,12 +66,35 @@ func runImport( filename string, numWorkers int, ) error { + if strings.HasSuffix(filename, ".csv") { + return runCsvImport(client, timeProvider, filename, numWorkers) + } else if strings.HasSuffix(filename, ".json") { + return runJsonImport(client, timeProvider, filename, numWorkers) + } - // ParseCsvFile returns a data channel and an optional error if any issues - // occurred while opening the file for reading - data, err := csv.ParseCsvFile(filename) + return errors.New("unrecognized file extension") +} + +func runJsonImport( + client dynamodb.DynamoDBClient, + timeProvider clock.TimeProvider, + filename string, + numWorkers int, +) (err error) { + fp, err := os.Open(filename) if err != nil { - return err + return + } + defer fp.Close() + contents, err := ioutil.ReadAll(fp) + if err != nil { + return + } + + var rules []fileRule + err = json.Unmarshal(contents, &rules) + if err != nil { + return } // Track a total number of lines processed @@ -75,32 +103,75 @@ func runImport( var total uint64 // Start the workers + rulesBuffer := make(chan fileRule) var wg sync.WaitGroup for w := 0; w < numWorkers; w++ { wg.Add(1) go func() { defer wg.Done() // ensure Done is called after this worker is complete - ddbWriter(client, timeProvider, data, &total) + ddbWriter( + client, + timeProvider, + rulesBuffer, + &total, + ) }() } - // chill + // Shovel all the json-parsed rules into the worker queue + for _, rule := range rules { + rulesBuffer <- rule + } + close(rulesBuffer) + + // Chill wg.Wait() fmt.Println("processed lines:", total) - return nil + return } -func ddbWriter( +func runCsvImport( client dynamodb.DynamoDBClient, timeProvider clock.TimeProvider, - lines chan map[string]string, - total *uint64, -) { - for line := range lines { - atomic.AddUint64(total, 1) + filename string, + numWorkers int, +) error { + // ParseCsvFile returns a data channel and an optional error if any issues + // occurred while opening the file for reading + data, err := csv.ParseCsvFile(filename) + if err != nil { + return err + } + // Channel for csv parsing and workers to communicate over + rules := make(chan fileRule) + + // Track a total number of lines processed + // This gets passed to workers and atomic.Add is + // used to increment in a thread-safe way + var total uint64 + + // Start the workers + // Fanning out workers allows us to make multiple HTTP requests concurrently which can + // improve performance assuming we aren't network I/O bottlenecked or something. + var wg sync.WaitGroup + for w := 0; w < numWorkers; w++ { + wg.Add(1) + go func() { + defer wg.Done() // ensure Done is called after this worker is complete + ddbWriter( + client, + timeProvider, + rules, + &total, + ) + }() + } + + // Start taking lines from the csv and shoveling them into the workers + for line := range data { sha256, ok := line["sha256"] if !ok { panic("no sha256") @@ -117,6 +188,11 @@ func ddbWriter( if !ok { description = "" } + customMsg, ok := line["custom_msg"] + if !ok { + customMsg = "" + } + var ruleType types.RuleType err := ruleType.UnmarshalText([]byte(ruleTypeStr)) if err != nil { @@ -128,14 +204,42 @@ func ddbWriter( panic("invalid policy") } + rules <- fileRule{ + SHA256: sha256, + RuleType: ruleType, + Policy: policy, + Description: description, + CustomMessage: customMsg, + } + } + close(rules) + + // chill + wg.Wait() + + fmt.Println("processed lines:", total) + + return nil +} + +func ddbWriter( + client dynamodb.DynamoDBClient, + timeProvider clock.TimeProvider, + rules chan fileRule, + total *uint64, +) { + for rule := range rules { + var err error + atomic.AddUint64(total, 1) + suffix := "" - if ruleType == types.Certificate { + if rule.RuleType == types.Certificate { suffix = " (Cert)" } - if policy == types.RulePolicyRemove { - fmt.Printf(" Removing rule: [%s]\n", sha256) - sortkey := rudolphrules.RuleSortKeyFromTypeSHA(sha256, ruleType) + if rule.Policy == types.RulePolicyRemove { + fmt.Printf(" Removing rule: [%s]\n", rule.SHA256) + sortkey := rudolphrules.RuleSortKeyFromTypeSHA(rule.SHA256, rule.RuleType) err = globalrules.RemoveGlobalRule( timeProvider, client, @@ -145,14 +249,14 @@ func ddbWriter( ) } else { - fmt.Printf(" Writing rule: [%s] %s%s\n", policyStr, sha256, suffix) + fmt.Printf(" Writing rule: [%+v] %s%s\n", rule.Policy, rule.SHA256, suffix) err = globalrules.AddNewGlobalRule( timeProvider, client, - sha256, - ruleType, - policy, - description, + rule.SHA256, + rule.RuleType, + rule.Policy, + rule.Description, ) } diff --git a/pkg/model/feedrules/query.go b/pkg/model/feedrules/query.go index 15c4700..6429a28 100644 --- a/pkg/model/feedrules/query.go +++ b/pkg/model/feedrules/query.go @@ -65,6 +65,6 @@ func GetPaginatedFeedRules(client dynamodb.QueryAPI, limit int, exclusiveStartKe err = errors.Wrap(err, "failed to unmarshal result from DynamoDB") return } - log.Printf(" got %d items from query.", len(*items)) + // log.Printf(" got %d items from query.", len(*items)) return } diff --git a/pkg/model/globalrules/query.go b/pkg/model/globalrules/query.go index 4c52094..8a4c233 100644 --- a/pkg/model/globalrules/query.go +++ b/pkg/model/globalrules/query.go @@ -1,8 +1,6 @@ package globalrules import ( - "log" - "github.com/airbnb/rudolph/pkg/dynamodb" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" @@ -59,7 +57,6 @@ func GetPaginatedGlobalRules(client dynamodb.QueryAPI, limit int, exclusiveStart err = errors.Wrap(err, "failed to unmarshall LastEvaluatedKey") return } - log.Printf(" lastEvaluatedKey: %+v", lastEvaluatedKey) } err = attributevalue.UnmarshalListOfMaps(result.Items, &items) @@ -67,6 +64,5 @@ func GetPaginatedGlobalRules(client dynamodb.QueryAPI, limit int, exclusiveStart err = errors.Wrap(err, "failed to unmarshal result from DynamoDB") return } - log.Printf(" got %d items from query.", len(*items)) return } diff --git a/pkg/model/machinerules/query.go b/pkg/model/machinerules/query.go index 69d056f..7e32032 100644 --- a/pkg/model/machinerules/query.go +++ b/pkg/model/machinerules/query.go @@ -28,21 +28,13 @@ func GetPrimaryKeysByMachineIDWhereMarkedForDeletion(client dynamodb.QueryAPI, m ProjectionExpression: aws.String("PK, SK"), } - // log.Printf("DDB Query Input:\n%+v", input) - output, err := client.Query(&input) - // log.Printf("Error:\n%+v", err) - // log.Printf("DDB Query Output:\n%+v", output) - // log.Printf("Discovered %d items", len(output.Items)) - if err != nil { return } err = attributevalue.UnmarshalListOfMaps(output.Items, &keys) - // log.Printf("Keys:\n%+v", *keys) - return } @@ -61,8 +53,6 @@ func GetMachineRules(client dynamodb.QueryAPI, machineID string) (items *[]Machi KeyConditionExpression: keyConditionExpression, } - // log.Printf("Executing DynamoDB Query:\n%+v", input) - result, err := client.Query(input) if err != nil { err = errors.Wrapf(err, "failed to read rules from DynamoDB for partitionKey %q", partitionKey) @@ -74,6 +64,5 @@ func GetMachineRules(client dynamodb.QueryAPI, machineID string) (items *[]Machi err = errors.Wrap(err, "failed to unmarshal result from DynamoDB") return } - // log.Printf(" got %d items from query.", len(*items)) return }