From f60948fcf47d0cdb7a0025548772b7ce6c493f48 Mon Sep 17 00:00:00 2001 From: rtfmkiesel <79413747+rtfmkiesel@users.noreply.github.com> Date: Mon, 25 Dec 2023 17:45:46 +0100 Subject: [PATCH] refactoring & house cleaning --- cmd/loldrivers-client/loldrivers-client.go | 133 ++-------- pkg/checksums/checksums.go | 209 ++------------- pkg/checksums/filewalker.go | 55 ++++ pkg/checksums/runner.go | 59 +++++ pkg/filesystem/filesystem.go | 93 ------- pkg/logger/logger.go | 19 -- pkg/loldrivers/loldrivers.go | 282 +++++---------------- pkg/loldrivers/parse.go | 169 ++++++++++++ pkg/loldrivers/parse_test.go | 7 +- pkg/options/options.go | 110 ++++++++ pkg/output/output.go | 58 ----- pkg/result/result.go | 52 ++++ 12 files changed, 546 insertions(+), 700 deletions(-) create mode 100644 pkg/checksums/filewalker.go create mode 100644 pkg/checksums/runner.go delete mode 100644 pkg/filesystem/filesystem.go create mode 100644 pkg/loldrivers/parse.go create mode 100644 pkg/options/options.go delete mode 100644 pkg/output/output.go create mode 100644 pkg/result/result.go diff --git a/cmd/loldrivers-client/loldrivers-client.go b/cmd/loldrivers-client/loldrivers-client.go index 178f7e0..0768c44 100644 --- a/cmd/loldrivers-client/loldrivers-client.go +++ b/cmd/loldrivers-client/loldrivers-client.go @@ -3,147 +3,54 @@ package main import ( - "flag" - "fmt" - "runtime" "sync" "time" "github.com/rtfmkiesel/loldrivers-client/pkg/checksums" - "github.com/rtfmkiesel/loldrivers-client/pkg/filesystem" "github.com/rtfmkiesel/loldrivers-client/pkg/logger" "github.com/rtfmkiesel/loldrivers-client/pkg/loldrivers" - "github.com/rtfmkiesel/loldrivers-client/pkg/output" + "github.com/rtfmkiesel/loldrivers-client/pkg/options" + "github.com/rtfmkiesel/loldrivers-client/pkg/result" ) func main() { - // To track execution time - startTime := time.Now() - - // Setup & parse command line arguments - var flagMode string - var flagDir string - var flagFileLimit int64 - var flagLocalFile string - var flagSilent bool - var flagJSON bool - var flagWorkers int - flag.StringVar(&flagMode, "m", "online", "") - flag.StringVar(&flagMode, "mode", "online", "") - flag.StringVar(&flagDir, "d", "", "") - flag.StringVar(&flagDir, "scan-dir", "", "") - flag.Int64Var(&flagFileLimit, "l", 10, "") - flag.Int64Var(&flagFileLimit, "scan-limit", 10, "") - flag.StringVar(&flagLocalFile, "f", "", "") - flag.StringVar(&flagLocalFile, "driver-file", "", "") - flag.BoolVar(&flagSilent, "s", false, "") - flag.BoolVar(&flagSilent, "silent", false, "") - flag.BoolVar(&flagJSON, "j", false, "") - flag.BoolVar(&flagJSON, "json", false, "") - flag.IntVar(&flagWorkers, "w", 20, "") - flag.IntVar(&flagWorkers, "workers", 20, "") - flag.Usage = func() { - logger.Banner() - fmt.Println(`Usage: - LOLDrivers-client.exe [OPTIONS] - -Options: - -m, --mode Operating Mode {online, local, internal} - online = Download the newest driver set (default) - local = Use a local drivers.json file (requires '-f') - internal = Use the built-in driver set (can be outdated) - - -d, --scan-dir Directory to scan for drivers (default: Windows driver folders) - Files which cannot be opened or read will be silently ignored - -l, --scan-limit Size limit for files to scan in MB (default: 10) - Be aware, higher values greatly increase runtime & CPU usage - - -f, --driver-file File path to 'drivers.json', when running with '-m local' - - -s, --silent Will only output found files for easy parsing (default: false) - -j, --json Format output as JSON (default: false) - - -w, --workers Number of "threads" to spawn (default: 20) - -h, --help Shows this text - `) - } - flag.Parse() - - // Only one output style - if flagSilent && flagJSON { - logger.Fatalf("only use '-s' or '-j', not both") - } else if flagSilent { - output.Mode = "silent" - logger.BeSilent = true - } else if flagJSON { - output.Mode = "json" - logger.BeSilent = true - } - - // ASCII L0VE - logger.Banner() - // Only run on Windows - if runtime.GOOS != "windows" { - logger.Fatalf("this client was made for Windows only") - } - - // Load the drivers - drivers, err := loldrivers.LoadDrivers(flagMode, flagLocalFile) + // Parse the command line options + opt, err := options.Parse() if err != nil { logger.Fatal(err) } - logger.Log(fmt.Sprintf("[+] Loaded %d drivers", len(drivers))) - // Get all hashes from the loaded drivers - driverHashes := loldrivers.GetHashes(drivers) - logger.Log(fmt.Sprintf(" |-- Got %d MD5 hashes", len(driverHashes.MD5Sums))) - logger.Log(fmt.Sprintf(" |-- Got %d SHA1 hashes", len(driverHashes.SHA1Sums))) - logger.Log(fmt.Sprintf(" |-- Got %d SHA256 hashes", len(driverHashes.SHA256Sums))) + // Load the drivers and their hashes + if err = loldrivers.LoadDrivers(opt.Mode, opt.LocalDriversPath); err != nil { + logger.Fatal(err) + } - // Create the channels and waitgroup for the checksum runners + // Set up the checksum runners chanFiles := make(chan string) - chanResults := make(chan output.Result) + chanResults := make(chan result.Result) wgRunner := new(sync.WaitGroup) - // Spawn the checksum runners - for i := 0; i <= flagWorkers; i++ { - go checksums.Runner(wgRunner, chanFiles, chanResults, driverHashes, drivers) + for i := 0; i <= opt.Workers; i++ { + go checksums.CalcRunner(wgRunner, chanFiles, chanResults) wgRunner.Add(1) } - // Create the waitgroup for the output runner - wgOutput := new(sync.WaitGroup) - // Spawn the output runner - go output.Runner(wgOutput, chanResults) - wgOutput.Add(1) - - // Set the folders which are going to be scanned for files - var paths []string - if flagDir == "" { - // User did not specify a path with '-d', use the default Windows paths - paths = append(paths, "C:\\Windows\\System32\\drivers") - paths = append(paths, "C:\\Windows\\System32\\DriverStore\\FileRepository") - paths = append(paths, "C:\\WINDOWS\\inf") - } else { - // User specified a custom folder to scan - paths = append(paths, flagDir) - } + // Set up the one output runner + wgResults := new(sync.WaitGroup) + go result.OutputRunner(wgResults, chanResults, opt.OutputMode) + wgResults.Add(1) // Get all files from subfolders and send them to the checksum runners via a channel - for _, path := range paths { - if err := filesystem.FileWalker(path, flagFileLimit, chanFiles); err != nil { + for _, path := range opt.ScanDirectories { + if err := checksums.FileWalker(path, opt.ScanSizeLimit, chanFiles); err != nil { logger.Fatal(err) } } - // Close the channel to start the checksum runners close(chanFiles) - // Wait here until all checksums are calculated and compared wgRunner.Wait() - // Close the results channel to process the results close(chanResults) - // Wait until all results have been processed - wgOutput.Wait() + wgResults.Wait() - logger.Log(fmt.Sprintf("[+] Done, took %s\n", time.Since(startTime))) + logger.Logf("[+] Done, took %s", time.Since(opt.StartTime)) } diff --git a/pkg/checksums/checksums.go b/pkg/checksums/checksums.go index 0ca933b..ca5f9b5 100644 --- a/pkg/checksums/checksums.go +++ b/pkg/checksums/checksums.go @@ -1,4 +1,3 @@ -// Package checksums calculates the MD5, SHA1 and SHA256 checksum of files package checksums import ( @@ -6,231 +5,57 @@ import ( "crypto/sha1" "crypto/sha256" "encoding/hex" - "fmt" "io" "os" - "strings" - "sync" - - "github.com/rtfmkiesel/loldrivers-client/pkg/filesystem" - "github.com/rtfmkiesel/loldrivers-client/pkg/logger" - "github.com/rtfmkiesel/loldrivers-client/pkg/loldrivers" - "github.com/rtfmkiesel/loldrivers-client/pkg/output" ) -// calcMD5() will return the MD5 checksum of the given file -func calcMD5(filepath string) (string, error) { - // Check if the file exists - if !filesystem.FileExists(filepath) { - return "", fmt.Errorf("file '%s' does not exist", filepath) - } - - // Open the file - file, err := os.Open(filepath) - // Check if the file can be accessed - if fileAccessErr(err) { - // No, skip file - return "", nil +// calcSHA256() will return the SHA256 checksum of filePath +func calcSHA256(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err } defer file.Close() - // Create a new MD5 - hash := md5.New() - - // Copy the file data into the hash + hash := sha256.New() if _, err := io.Copy(hash, file); err != nil { - errormsg := fmt.Sprintf("%s", err) - // Ignore read error "another process has locked a portion of the file" - if strings.Contains(strings.ToLower(errormsg), "has locked") { - return "", nil - } - return "", err } - // Get the checksum checksum := hash.Sum(nil) - - // Convert the checksum to a hex string return hex.EncodeToString(checksum), nil } -// calcSHA1 will return the SHA1 checksum of the given file -func calcSHA1(filepath string) (string, error) { - // Check if the file exists - if !filesystem.FileExists(filepath) { - return "", fmt.Errorf("file '%s' does not exist", filepath) - } - - // Open the file - file, err := os.Open(filepath) - // Check if the file can be accessed - if fileAccessErr(err) { - // No, skip file - return "", nil +// calcSHA1() will return the SHA1 checksum of filePath +func calcSHA1(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err } defer file.Close() - // Create a new SHA1 hash := sha1.New() - - // Copy the file data into the hash if _, err := io.Copy(hash, file); err != nil { - errormsg := fmt.Sprintf("%s", err) - // Ignore read error "another process has locked a portion of the file" - if strings.Contains(strings.ToLower(errormsg), "has locked") { - return "", nil - } - return "", err } - // Get the checksum checksum := hash.Sum(nil) - - // Convert the checksum to a hex string return hex.EncodeToString(checksum), nil } -// calcSHA256 will return the SHA256 checksum of the given file -func calcSHA256(filepath string) (string, error) { - // Check if the file exists - if !filesystem.FileExists(filepath) { - return "", fmt.Errorf("file '%s' does not exist", filepath) - } - - // Open the file - file, err := os.Open(filepath) - // Check if the file can be accessed - if fileAccessErr(err) { - // No, skip file - return "", nil +// calcMD5() will return the MD5 checksum of filePath +func calcMD5(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err } defer file.Close() - // Create a new SHA256 - hash := sha256.New() - - // Copy the file data into the hash + hash := md5.New() if _, err := io.Copy(hash, file); err != nil { - errormsg := fmt.Sprintf("%s", err) - // Ignore read error "another process has locked a portion of the file" - if strings.Contains(strings.ToLower(errormsg), "has locked") { - return "", nil - } - return "", err } - // Get the checksum checksum := hash.Sum(nil) - - // Convert the checksum to a hex string return hex.EncodeToString(checksum), nil } - -// contains() will return true if a []string contains a string -func contains(slice []string, value string) bool { - for _, s := range slice { - if s == value { - return true - } - } - - return false -} - -// fileAccessErr() will handle ACL errors. -// It will return true if a file should be skipped (can't be read/opened) -func fileAccessErr(err error) bool { - if err == nil { - return false - } - - errormsg := fmt.Sprintf("%s", err) - // Ignore open errors "cannot access the file", "file cannot be accessed", "Access denied" - if strings.Contains(strings.ToLower(errormsg), "access") { - return true - } - - // Ignore "file does not exist" error because files could have been removed in the meantime - // os.IsExist does not work - if strings.Contains(strings.ToLower(errormsg), "does not exist") { - return true - } - - return false -} - -// Runner() is used as a go func for calculating and comparing file checksums -// from a job channel of filenames. If a calculated checksum matches a loaded checksum -// a result in the form of logger.Result will be sent to an output channel -func Runner(wg *sync.WaitGroup, chanJobs <-chan string, chanResults chan<- output.Result, checksums loldrivers.DriverHashes, drivers []loldrivers.Driver) { - defer wg.Done() - - // For each job - for job := range chanJobs { - // Calculate the MD5 - MD5, err := calcMD5(job) - if err != nil { - logger.Error(err) - continue - } - // Check if the MD5 in in the driver slice - if contains(checksums.MD5Sums, MD5) { - // Find the matching driver - // Error ignored since there must be a match - driver, _ := loldrivers.MatchHash(MD5, drivers) - - // Send result to the output channel - chanResults <- output.Result{ - Filepath: job, - Checksum: MD5, - Driver: driver, - } - continue - } - - // Calculate the SHA1 - SHA1, err := calcSHA1(job) - if err != nil { - logger.Error(err) - continue - } - // Check if the SHA1 in in the driver slice - if contains(checksums.SHA1Sums, SHA1) { - // Find the matching driver - // Error ignored since there must be a match - driver, _ := loldrivers.MatchHash(SHA1, drivers) - - // Send result to the output channel - chanResults <- output.Result{ - Filepath: job, - Checksum: SHA1, - Driver: driver, - } - continue - } - - // Calculate the SHA256 - SHA256, err := calcSHA256(job) - if err != nil { - logger.Error(err) - continue - } - // Check if the SHA256 in in the driver slice - if contains(checksums.SHA256Sums, SHA256) { - // Find the matching driver - // Error ignored since there must be a match - driver, _ := loldrivers.MatchHash(SHA256, drivers) - - // Send result to the output channel - chanResults <- output.Result{ - Filepath: job, - Checksum: SHA256, - Driver: driver, - } - continue - } - } -} diff --git a/pkg/checksums/filewalker.go b/pkg/checksums/filewalker.go new file mode 100644 index 0000000..dbb4072 --- /dev/null +++ b/pkg/checksums/filewalker.go @@ -0,0 +1,55 @@ +package checksums + +import ( + "os" + "path/filepath" + + "github.com/rtfmkiesel/loldrivers-client/pkg/logger" +) + +// checksums.FileWalker() will recursively send files from path, who are smaller than sizeLimit, to outputChannel +func FileWalker(path string, sizeLimit int64, outputChannel chan<- string) (err error) { + logger.Logf("[*] Searching for files in %s", path) + + // Walk over every file in a given folder + err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + if err != nil { + // Ignore a file if we get "Access is denied" error + if os.IsPermission(err) { + return nil + } + + // Ignore a file if we get "The system cannot find the file specified" error + if os.IsNotExist(err) { + return nil + } + + return err + } + + // Skip directories and non regular files + if info.IsDir() || !info.Mode().IsRegular() { + return nil + } + + // Skip files that can't be read + if info.Mode().Perm()&0400 == 0 { + return nil + } + + // Skip files larger than the specified size limit + if info.Size() > sizeLimit*1024*1024 { + return nil + } + + // Send to the channel + outputChannel <- path + return nil + }) + + if err != nil { + return err + } + + return nil +} diff --git a/pkg/checksums/runner.go b/pkg/checksums/runner.go new file mode 100644 index 0000000..d964f18 --- /dev/null +++ b/pkg/checksums/runner.go @@ -0,0 +1,59 @@ +package checksums + +import ( + "sync" + + "github.com/rtfmkiesel/loldrivers-client/pkg/logger" + "github.com/rtfmkiesel/loldrivers-client/pkg/loldrivers" + "github.com/rtfmkiesel/loldrivers-client/pkg/result" +) + +// checksums.CalcRunner() is used as a go func for calculating and comparing file checksums +// from a chanJobs. If a calculated checksum matches a loaded checksum a result.Result will be sent to chanResults +func CalcRunner(wg *sync.WaitGroup, chanJobs <-chan string, chanResults chan<- result.Result) { + defer wg.Done() + + for job := range chanJobs { + // SHA256 + sha256, err := calcSHA256(job) + if err != nil { + logger.Error(err) + } else if driver := loldrivers.MatchHash(sha256); driver != nil { + chanResults <- result.Result{ + Filepath: job, + Checksum: sha256, + Driver: *driver, + } + + continue + } + + // SHA1 + sha1, err := calcSHA1(job) + if err != nil { + logger.Error(err) + } else if driver := loldrivers.MatchHash(sha1); driver != nil { + chanResults <- result.Result{ + Filepath: job, + Checksum: sha1, + Driver: *driver, + } + + continue + } + + // MD5 + md5, err := calcMD5(job) + if err != nil { + logger.Error(err) + } else if driver := loldrivers.MatchHash(md5); driver != nil { + chanResults <- result.Result{ + Filepath: job, + Checksum: md5, + Driver: *driver, + } + + continue + } + } +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go deleted file mode 100644 index 18b4f60..0000000 --- a/pkg/filesystem/filesystem.go +++ /dev/null @@ -1,93 +0,0 @@ -// Package filesystem handles filesystem operations -package filesystem - -import ( - "fmt" - "io" - "os" - "path/filepath" - - "github.com/rtfmkiesel/loldrivers-client/pkg/logger" -) - -// FileExists() will return true if a given file exists -func FileExists(filepath string) bool { - if _, err := os.Stat(filepath); os.IsNotExist(err) { - return false - } else { - return true - } -} - -// FileRead() will return the contents of a file as bytes -func FileRead(filepath string) (contentBytes []byte, err error) { - // Check if file exists - if !FileExists(filepath) { - return nil, fmt.Errorf("file '%s' does not exist", filepath) - } - - // Open file - file, err := os.Open(filepath) - if err != nil { - return nil, fmt.Errorf("could not open file '%s'", filepath) - } - defer file.Close() - - // Read file - contentBytes, err = io.ReadAll(file) - if err != nil { - return nil, fmt.Errorf("could not read file '%s'", filepath) - } - - return contentBytes, nil -} - -// FileWalker() will recursively send files from a directory, who are smaller than the -// specified size limit, to a string channel -// -// sizeLimit as int64 in MB (ex: 5) -func FileWalker(path string, sizeLimit int64, outputChannel chan<- string) (err error) { - logger.Log(fmt.Sprintf("[*] Searching for files in %s", path)) - - // Walk over every file in a given folder - err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { - if err != nil { - // Ignore a file if we get "Access is denied" error - if os.IsPermission(err) { - return nil - } - - // Ignore a file if we get "The system cannot find the file specified" error - if os.IsNotExist(err) { - return nil - } - - return err - } - - // Skip directories and non regular files - if info.IsDir() || !info.Mode().IsRegular() { - return nil - } - - // Skip files that can't be read - if info.Mode().Perm()&0400 == 0 { - return nil - } - - // Skip files larger than the specified size limit - if info.Size() > sizeLimit*1024*1024 { - return nil - } - - // Send to the channel - outputChannel <- path - return nil - }) - - if err != nil { - return err - } - - return nil -} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 756eb17..ae5856f 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -1,4 +1,3 @@ -// Package logger handles errors as well as the output handling package logger import ( @@ -36,26 +35,8 @@ func Error(err error) { fmt.Fprintf(os.Stderr, "[!] ERROR: %s\n", err) } -func Errorf(msg string, args ...interface{}) { - msg = fmt.Sprintf("[!] ERROR: "+msg, args...) - if strings.HasSuffix(msg, "\n") { - fmt.Fprint(os.Stderr, msg) - } else { - fmt.Fprint(os.Stderr, msg+"\n") - } -} - func Fatal(err error) { fmt.Fprintf(os.Stderr, "[!] ERROR: %s\n", err) -} - -func Fatalf(msg string, args ...interface{}) { - msg = fmt.Sprintf("[!] ERROR: "+msg, args...) - if strings.HasSuffix(msg, "\n") { - fmt.Fprint(os.Stderr, msg) - } else { - fmt.Fprint(os.Stderr, msg+"\n") - } os.Exit(1) } diff --git a/pkg/loldrivers/loldrivers.go b/pkg/loldrivers/loldrivers.go index c9610f0..01a8d90 100644 --- a/pkg/loldrivers/loldrivers.go +++ b/pkg/loldrivers/loldrivers.go @@ -1,291 +1,125 @@ -// Package loldrivers handles the JSON data from loldrivers.io package loldrivers import ( _ "embed" - "encoding/json" "fmt" "io" - "net/http" + "os" - "github.com/rtfmkiesel/loldrivers-client/pkg/filesystem" "github.com/rtfmkiesel/loldrivers-client/pkg/logger" ) var ( - // Embed a driver.json during build for use with -m 'internal' - //go:embed drivers.json - internalDrivers []byte + LoadedDrivers []Driver + LoadedHashes DriverHashes ) -const ( - // Download link to the 'drivers.json' file - apiURL = "https://www.loldrivers.io/api/drivers.json" -) - -// Struct for a single driver from loldrivers.io -// -// Based on the the JSON spec from -// https://github.com/magicsword-io/LOLDrivers/blob/validate/bin/spec/drivers.spec.json -type Driver struct { - ID string `json:"Id"` - Author string `json:"Author"` - Created string `json:"Created"` - MitreID string `json:"MitreID"` - Category string `json:"Category"` - Verified string `json:"Verified"` - Commands unmarshalCommands `json:"Commands,omitempty"` - Resources []string `json:"Resources,omitempty"` - Acknowledgement struct { - Person unmarshalStringOrStringArray `json:"Person"` - Handle string `json:"Handle"` - } `json:"Acknowledgement,omitempty"` - Detection []struct { - Type string `json:"type"` - Value string `json:"value"` - } `json:"Detection,omitempty"` - KnownVulnerableSamples []struct { - Filename string `json:"Filename"` - MD5 string `json:"MD5,omitempty"` - SHA1 string `json:"SHA1,omitempty"` - SHA256 string `json:"SHA256,omitempty"` - } `json:"KnownVulnerableSamples,omitempty"` - Tags []string `json:"Tags"` -} - -// 'Command' struct for a driver from loldrivers.io -// -// Based on the the JSON spec from -// https://github.com/magicsword-io/LOLDrivers/blob/validate/bin/spec/drivers.spec.json -type Command struct { - Command string `json:"Command"` - Description string `json:"Description"` - Usecase string `json:"Usecase"` - Privileges string `json:"Privileges"` - OperatingSystem string `json:"OperatingSystem"` -} - -// Struct to store the driver hashes from loldrivers.io -type DriverHashes struct { - MD5Sums []string - SHA1Sums []string - SHA256Sums []string -} - -// Struct that is used during unmarshalling of the "Commands" JSON data -// since sometimes it'll be either a single string or a "Command" struct -type unmarshalCommands struct { - Value []Command - Set bool -} - -// Struct that is used during unmarshalling of various the JSON data -// since sometimes a key can be either a single string or an array of strings -type unmarshalStringOrStringArray struct { - Value []string - Set bool -} - -// The UnmarshalJSON method on UnmarshalCommands will parse the JSON -// as eiter a "Command" struct or a single string (into a "Command" struct) -func (s *unmarshalCommands) UnmarshalJSON(b []byte) error { - var strVal string - var cmdVal Command - // Try to unmarshal into a string first - err := json.Unmarshal(b, &strVal) - if err == nil { - // No error, set string value, leave rest empty - cmdVal = Command{ - Command: strVal, - } - } else { - // Try to unmarshall into a "Command" struct - err = json.Unmarshal(b, &cmdVal) - if err != nil { - // Both unmarshall were unsuccessful - return err - } - } - // Set the value of s to the unmarshalled value - s.Value = append(s.Value, cmdVal) - s.Set = true - return nil -} - -// The UnmarshalJSON method will parse the JSON as either a single string -// or an array of strings into a slice of strings -func (s *unmarshalStringOrStringArray) UnmarshalJSON(b []byte) error { - var strVal string - var arrVal []string - // Try to unmarshal into a single string first - err := json.Unmarshal(b, &strVal) - if err == nil { - // No error, create a array with a single value - arrVal = []string{strVal} - } else { - // Try to unmarshall into a slice of strings - err = json.Unmarshal(b, &arrVal) - if err != nil { - // Both unmarshall were unsuccessful - return err - } - } - // Set the value of s to the unmarshalled value which will always be a slice - s.Value = arrVal - s.Set = true - return nil -} - // LoadDrivers() will load the drivers based on the selected mode -// and return a slice of loldrivers.Driver // // mode = online, local, internal // -// filepath = path to JSON file if mode == local, else "" -func LoadDrivers(mode string, filepath string) (drivers []Driver, err error) { - // Load the drivers based on the selected mode - logger.Log(fmt.Sprintf("[*] Loading drivers with mode '%s'", mode)) - +// filePath = path to JSON file if mode == ModeLocal +func LoadDrivers(mode string, filePath string) (err error) { switch mode { case "online": - // Default, download from the web - // Download drivers - drivers, err = download() + jsonBytes, err := download() if err != nil { - // There was a parsing error logger.Error(err) - logger.Log("[!] Got an error while parsing online data. Falling back to internal data set") - drivers, err = parse(internalDrivers) + logger.Log("[*] Got an error while downloading data. Falling back to internal data set") + LoadedDrivers, err = parse(internalDrivers) + if err != nil { + return err + } + } else { + LoadedDrivers, err = parse(jsonBytes) if err != nil { - return drivers, err + logger.Error(err) + logger.Log("[*] Got an error while parsing data. Falling back to internal data set") + LoadedDrivers, err = parse(internalDrivers) + if err != nil { + return err + } } } case "local": - // User wants to use a local file - if filepath == "" { - logger.Fatalf("mode 'local' requires '-f'") + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("could not open file '%s'", filePath) } + defer file.Close() - // Read file - jsonBytes, err := filesystem.FileRead(filepath) + content, err := io.ReadAll(file) if err != nil { - return drivers, err + return fmt.Errorf("could not read file '%s'", filePath) } - // Parse file - drivers, err = parse(jsonBytes) + LoadedDrivers, err = parse(content) if err != nil { - // There was a parsing error logger.Error(err) - logger.Log("[!] Got an error while parsing local file. Falling back to internal data set") - drivers, err = parse(internalDrivers) + logger.Logf("[*] Got an error while parsing '%s'. Falling back to internal data set", filePath) + LoadedDrivers, err = parse(internalDrivers) if err != nil { - return drivers, err + return err } } case "internal": - // User wants to use internal data set - // Parse bytes - drivers, err = parse(internalDrivers) + LoadedDrivers, err = parse(internalDrivers) if err != nil { - return drivers, err + return err } default: - logger.Fatalf("invalid mode '%s'", mode) + return fmt.Errorf("invalid mode '%s'", mode) } - return drivers, nil -} + logger.Logf("[+] Loaded %d drivers", len(LoadedDrivers)) -// GetHashes() will return loldrivers.DriverHashes containing all -// MD5, SHA1 and SHA256 from a slice of loldrivers.Driver. Empty values or '-' will be ignored -func GetHashes(drivers []Driver) (driverHashes DriverHashes) { - // Get all checksums from the loaded drivers - for _, driver := range drivers { + for _, driver := range LoadedDrivers { for _, knownVulnSample := range driver.KnownVulnerableSamples { // Append MD5 if exist if knownVulnSample.MD5 != "" && knownVulnSample.MD5 != "-" { - driverHashes.MD5Sums = append(driverHashes.MD5Sums, knownVulnSample.MD5) + LoadedHashes.MD5Sums = append(LoadedHashes.MD5Sums, knownVulnSample.MD5) } // Append SHA1 if exist if knownVulnSample.SHA1 != "" && knownVulnSample.SHA1 != "-" { - driverHashes.SHA1Sums = append(driverHashes.SHA1Sums, knownVulnSample.SHA1) + LoadedHashes.SHA1Sums = append(LoadedHashes.SHA1Sums, knownVulnSample.SHA1) } // Append SHA256 if exist if knownVulnSample.SHA256 != "" && knownVulnSample.SHA256 != "-" { - driverHashes.SHA256Sums = append(driverHashes.SHA256Sums, knownVulnSample.SHA256) + LoadedHashes.SHA256Sums = append(LoadedHashes.SHA256Sums, knownVulnSample.SHA256) } } } - return driverHashes + logger.Logf(" |--> %d MD5 hashes", len(LoadedHashes.MD5Sums)) + logger.Logf(" |--> %d SHA1 hashes", len(LoadedHashes.SHA1Sums)) + logger.Logf(" |--> %d SHA256 hashes", len(LoadedHashes.SHA256Sums)) + + return nil } // MatchHash() will return the matching loldrivers.Driver for a given hash or else will return an error -func MatchHash(hash string, drivers []Driver) (match Driver, err error) { - // Get all checksums from the loaded drivers - for _, driver := range drivers { +func MatchHash(hash string) (match *Driver) { + for _, driver := range LoadedDrivers { for _, knownVulnSample := range driver.KnownVulnerableSamples { - if knownVulnSample.MD5 == hash { - return driver, nil - } - if knownVulnSample.SHA1 == hash { - return driver, nil - } - if knownVulnSample.SHA256 == hash { - return driver, nil + switch len(hash) { + case 32: + if knownVulnSample.MD5 == hash { + return &driver + } + case 40: + if knownVulnSample.SHA1 == hash { + return &driver + } + case 64: + if knownVulnSample.SHA256 == hash { + return &driver + } } } } - return match, fmt.Errorf("no match found") -} - -// download() will download the current loldrivers.io data set -func download() (drivers []Driver, err error) { - logger.Log("[*] Downloading the newest drivers") - - // Setup HTTP client - client := &http.Client{} - - // Build request - request, err := http.NewRequest("GET", apiURL, nil) - if err != nil { - return nil, err - } - - // Make the request - response, err := client.Do(request) - if err != nil { - return nil, err - } - defer response.Body.Close() - logger.Log("[+] Download successful") - - // Read the bode into []byte - jsonBytes, err := io.ReadAll(response.Body) - if err != nil { - return nil, err - } - - // Parse the data - drivers, err = parse(jsonBytes) - if err != nil { - return nil, err - } - - return drivers, nil -} - -// parse() will return a slice of loldrivers.Drivers from JSON input bytes -func parse(jsonBytes []byte) (drivers []Driver, err error) { - // Unmarshal JSON data - if err := json.Unmarshal(jsonBytes, &drivers); err != nil { - return nil, err - } - - return drivers, nil + return nil } diff --git a/pkg/loldrivers/parse.go b/pkg/loldrivers/parse.go new file mode 100644 index 0000000..3ca344f --- /dev/null +++ b/pkg/loldrivers/parse.go @@ -0,0 +1,169 @@ +package loldrivers + +import ( + _ "embed" + "encoding/json" + "io" + "net/http" + + "github.com/rtfmkiesel/loldrivers-client/pkg/logger" +) + +var ( + // Embed a driver.json during build for use with -m 'internal' + //go:embed drivers.json + internalDrivers []byte +) + +const ( + // Download Url + apiUrl = "https://www.loldrivers.io/api/drivers.json" +) + +// Struct for a single driver from loldrivers.io +// +// Based on the the JSON spec from +// https://github.com/magicsword-io/LOLDrivers/blob/validate/bin/spec/drivers.spec.json +type Driver struct { + ID string `json:"Id"` + Author string `json:"Author"` + Created string `json:"Created"` + MitreID string `json:"MitreID"` + Category string `json:"Category"` + Verified string `json:"Verified"` + Commands unmarshalCommands `json:"Commands,omitempty"` + Resources []string `json:"Resources,omitempty"` + Acknowledgement struct { + Person unmarshalStringOrStringArray `json:"Person"` + Handle string `json:"Handle"` + } `json:"Acknowledgement,omitempty"` + Detection []struct { + Type string `json:"type"` + Value string `json:"value"` + } `json:"Detection,omitempty"` + KnownVulnerableSamples []struct { + Filename string `json:"Filename"` + MD5 string `json:"MD5,omitempty"` + SHA1 string `json:"SHA1,omitempty"` + SHA256 string `json:"SHA256,omitempty"` + } `json:"KnownVulnerableSamples,omitempty"` + Tags []string `json:"Tags"` +} + +// 'Command' struct for a driver from loldrivers.io +// +// Based on the the JSON spec from +// https://github.com/magicsword-io/LOLDrivers/blob/validate/bin/spec/drivers.spec.json +type Command struct { + Command string `json:"Command"` + Description string `json:"Description"` + Usecase string `json:"Usecase"` + Privileges string `json:"Privileges"` + OperatingSystem string `json:"OperatingSystem"` +} + +// Struct to store the driver hashes from loldrivers.io +type DriverHashes struct { + MD5Sums []string + SHA1Sums []string + SHA256Sums []string +} + +// Struct that is used during unmarshalling of the "Commands" JSON data +// since sometimes it'll be either a single string or a "Command" struct +type unmarshalCommands struct { + Value []Command + Set bool +} + +// Struct that is used during unmarshalling of various the JSON data +// since sometimes a key can be either a single string or an array of strings +type unmarshalStringOrStringArray struct { + Value []string + Set bool +} + +// The UnmarshalJSON method on UnmarshalCommands will parse the JSON +// as eiter a "Command" struct or a single string (into a "Command" struct) +func (s *unmarshalCommands) UnmarshalJSON(b []byte) error { + var strVal string + var cmdVal Command + // Try to unmarshal into a string first + err := json.Unmarshal(b, &strVal) + if err == nil { + // No error, set string value, leave rest empty + cmdVal = Command{ + Command: strVal, + } + } else { + // Try to unmarshall into a "Command" struct + err = json.Unmarshal(b, &cmdVal) + if err != nil { + // Both unmarshall were unsuccessful + return err + } + } + // Set the value of s to the unmarshalled value + s.Value = append(s.Value, cmdVal) + s.Set = true + return nil +} + +// The UnmarshalJSON method will parse the JSON as either a single string +// or an array of strings into a slice of strings +func (s *unmarshalStringOrStringArray) UnmarshalJSON(b []byte) error { + var strVal string + var arrVal []string + // Try to unmarshal into a single string first + err := json.Unmarshal(b, &strVal) + if err == nil { + // No error, create a array with a single value + arrVal = []string{strVal} + } else { + // Try to unmarshall into a slice of strings + err = json.Unmarshal(b, &arrVal) + if err != nil { + // Both unmarshall were unsuccessful + return err + } + } + // Set the value of s to the unmarshalled value which will always be a slice + s.Value = arrVal + s.Set = true + return nil +} + +// download() will download the current loldrivers.io data set as []byte +func download() ([]byte, error) { + logger.Log("[*] Downloading the newest drivers") + + client := &http.Client{} + request, err := http.NewRequest("GET", apiUrl, nil) + if err != nil { + return nil, err + } + + request.Header.Set("User-Agent", "LOLDrivers-client") + + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + jsonBytes, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return jsonBytes, nil +} + +// parse() will return a slice of loldrivers.Drivers from JSON input bytes +func parse(jsonBytes []byte) (drivers []Driver, err error) { + if err := json.Unmarshal(jsonBytes, &drivers); err != nil { + return nil, err + } + + return drivers, nil +} diff --git a/pkg/loldrivers/parse_test.go b/pkg/loldrivers/parse_test.go index d16c830..1c8b063 100644 --- a/pkg/loldrivers/parse_test.go +++ b/pkg/loldrivers/parse_test.go @@ -4,7 +4,12 @@ import "testing" // online func TestOnlineParse(t *testing.T) { - _, err := download() + jsonBytes, err := download() + if err != nil { + t.Error(err) + } + + _, err = parse(jsonBytes) if err != nil { t.Error(err) } diff --git a/pkg/options/options.go b/pkg/options/options.go new file mode 100644 index 0000000..5665053 --- /dev/null +++ b/pkg/options/options.go @@ -0,0 +1,110 @@ +package options + +import ( + "flag" + "fmt" + "time" + + "github.com/rtfmkiesel/loldrivers-client/pkg/logger" +) + +type Options struct { + Mode string + LocalDriversPath string + ScanDirectories []string + ScanSizeLimit int64 + OutputMode string + Workers int + StartTime time.Time +} + +// Parse the command line options into an Options struct +func Parse() (opt *Options, err error) { + opt = &Options{} + + opt.StartTime = time.Now() + + var flagDir string + var flagSilent bool + var flagJson bool + flag.StringVar(&opt.Mode, "m", "online", "") + flag.StringVar(&opt.Mode, "mode", "online", "") + flag.StringVar(&opt.LocalDriversPath, "f", "", "") + flag.StringVar(&opt.LocalDriversPath, "driver-file", "", "") + flag.StringVar(&flagDir, "d", "", "") + flag.StringVar(&flagDir, "scan-dir", "", "") + flag.Int64Var(&opt.ScanSizeLimit, "l", 10, "") + flag.Int64Var(&opt.ScanSizeLimit, "scan-limit", 10, "") + flag.BoolVar(&flagSilent, "s", false, "") + flag.BoolVar(&flagSilent, "silent", false, "") + flag.BoolVar(&flagJson, "j", false, "") + flag.BoolVar(&flagJson, "json", false, "") + flag.IntVar(&opt.Workers, "w", 20, "") + flag.IntVar(&opt.Workers, "workers", 20, "") + flag.Usage = func() { usage() } + flag.Parse() + + switch opt.Mode { + case "online", "internal": + // we good + case "local": + if opt.Mode == "local" && opt.LocalDriversPath == "" { + return nil, fmt.Errorf("mode 'local' requires '-f'") + } + default: + return nil, fmt.Errorf("invalid mode '%s'", opt.Mode) + } + + // Only one output style + if flagSilent && flagJson { + return nil, fmt.Errorf("only use '-s' or '-j', not both") + } else if flagSilent { + opt.OutputMode = "silent" + logger.BeSilent = true + } else if flagJson { + opt.OutputMode = "json" + logger.BeSilent = true + } + + logger.Banner() + + // Directories + if flagDir == "" { + // User did not specify a path with '-d', use the default Windows opt.Directories + opt.ScanDirectories = append(opt.ScanDirectories, "C:\\Windows\\System32\\drivers") + opt.ScanDirectories = append(opt.ScanDirectories, "C:\\Windows\\System32\\DriverStore\\FileRepository") + opt.ScanDirectories = append(opt.ScanDirectories, "C:\\WINDOWS\\inf") + } else { + // User specified a custom folder to scan + opt.ScanDirectories = append(opt.ScanDirectories, flagDir) + } + + return opt, nil +} + +func usage() { + logger.Banner() + fmt.Println(`Usage: + LOLDrivers-client.exe [OPTIONS] + +Options: + -m, --mode Operating Mode {online, local, internal} + online = Download the newest driver set (default) + local = Use a local drivers.json file (requires '-f') + internal = Use the built-in driver set (can be outdated) + + -f, --driver-file File path to 'drivers.json', when running in local mode + + -d, --scan-dir Directory to scan for drivers (default: Windows driver folders) + Files which cannot be opened or read will be silently ignored + + -l, --scan-limit Size limit for files to scan in MB (default: 10) + Be aware, higher values greatly increase runtime & CPU usage + + -s, --silent Will only output found files for easy parsing (default: false) + -j, --json Format output as JSON (default: false) + + -w, --workers Number of "threads" to spawn (default: 20) + -h, --help Shows this text + `) +} diff --git a/pkg/output/output.go b/pkg/output/output.go deleted file mode 100644 index a0a0b93..0000000 --- a/pkg/output/output.go +++ /dev/null @@ -1,58 +0,0 @@ -// package output handles the printing of the results to the terminal -package output - -import ( - "encoding/json" - "fmt" - "sync" - - "github.com/rtfmkiesel/loldrivers-client/pkg/logger" - "github.com/rtfmkiesel/loldrivers-client/pkg/loldrivers" -) - -var ( - Mode = "default" -) - -// Struct for the results of the compare runners -type Result struct { - Filepath string - Checksum string - Driver loldrivers.Driver -} - -// Runner() is used as a go func to display the results -func Runner(wg *sync.WaitGroup, chanResults <-chan Result) { - defer wg.Done() - - // To count how many results we got - counter := 0 - - // For each result - for result := range chanResults { - // Print result based on output - switch Mode { - case "silent": - fmt.Printf("%s\n", result.Filepath) - case "json": - jsonOutput, err := json.Marshal(result) - if err != nil { - logger.Fatal(err) - } - fmt.Printf("%s\n", string(jsonOutput)) - default: - fmt.Printf("[!] MATCH: %s\n", result.Filepath) - //fmt.Printf(" |-- Category: %s\n", result.Driver.Category) - fmt.Printf(" |-- Checksum: %s\n", result.Checksum) - fmt.Printf(" |-- Link: https://loldrivers.io/drivers/%s\n", result.Driver.ID) - } - - counter++ - } - - if counter == 0 { - logger.Log("[-] No vulnerable or malicious driver(s) found!") - } else { - logger.Log(fmt.Sprintf("[+] Found a total of %d vulnerable or malicious driver(s)!", counter)) - } -} diff --git a/pkg/result/result.go b/pkg/result/result.go new file mode 100644 index 0000000..3d30078 --- /dev/null +++ b/pkg/result/result.go @@ -0,0 +1,52 @@ +package result + +import ( + "encoding/json" + "fmt" + "os" + "sync" + + "github.com/rtfmkiesel/loldrivers-client/pkg/logger" + "github.com/rtfmkiesel/loldrivers-client/pkg/loldrivers" +) + +// Struct for the results of the compare runners +type Result struct { + Filepath string + Checksum string + Driver loldrivers.Driver +} + +// OutputRunner() is used as a go func to display the results +func OutputRunner(wg *sync.WaitGroup, chanResults <-chan Result, mode string) { + defer wg.Done() + + // To count how many results we got + counter := 0 + + for result := range chanResults { + switch mode { + case "silent": + fmt.Fprintf(os.Stdout, "%s\n", result.Filepath) + case "json": + jsonOutput, err := json.Marshal(result) + if err != nil { + logger.Fatal(err) + } + fmt.Fprintf(os.Stdout, "%s\n", string(jsonOutput)) + default: + fmt.Fprintf(os.Stdout, "[!] Found %s\n", result.Driver.Category) + fmt.Fprintf(os.Stdout, " |--> Path: %s\n", result.Filepath) + //fmt.Fprintf(os.Stdout, " |--> Checksum: %s\n", result.Checksum) + fmt.Fprintf(os.Stdout, " |--> Link: https://loldrivers.io/drivers/%s\n", result.Driver.ID) + } + + counter++ + } + + if counter == 0 { + logger.Log("[-] No vulnerable or malicious driver(s) found!") + } else { + logger.Logf("[+] Found a total of %d vulnerable or malicious driver(s)!", counter) + } +}