Skip to content

Commit

Permalink
refactor(db): return error and use pipeline in Redis GetMulti (#63)
Browse files Browse the repository at this point in the history
* refactor(db): return error and add GetMulti in server mode

* refactor(redis): use pipeline in GetMulti

* chore: fix file path

* chore: change error handling

* feat(fetcher/exploitdb): do not return nil

* chore(rdb): handle ErrRecordNotFound
  • Loading branch information
MaineK00n authored Sep 30, 2021
1 parent 71f38da commit 10b78a3
Show file tree
Hide file tree
Showing 10 changed files with 330 additions and 172 deletions.
2 changes: 2 additions & 0 deletions GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ fetch-redis:

diff-cveid:
@ python integration/diff_server_mode.py cveid --sample_rate 0.01
@ python integration/diff_server_mode.py cveids --sample_rate 0.01

diff-uniqueid:
@ python integration/diff_server_mode.py uniqueid --sample_rate 0.01
@ python integration/diff_server_mode.py uniqueids --sample_rate 0.01

diff-server-rdb:
integration/exploitdb.old server --dbpath=$(PWD)/integration/go-exploitdb.old.sqlite3 --port 1325 > /dev/null 2>&1 &
Expand Down
12 changes: 9 additions & 3 deletions commands/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func init() {
viper.SetDefault("param", "")
}

func searchExploit(cmd *cobra.Command, args []string) (err error) {
func searchExploit(cmd *cobra.Command, args []string) error {
if err := util.SetLogger(viper.GetBool("log-to-file"), viper.GetString("log-dir"), viper.GetBool("debug"), viper.GetBool("log-json")); err != nil {
return xerrors.Errorf("Failed to SetLogger. err: %w", err)
}
Expand All @@ -66,13 +66,19 @@ func searchExploit(cmd *cobra.Command, args []string) (err error) {
log15.Error("Specify the search type [CVE] parameters like `--param CVE-xxxx-xxxx`")
return errors.New("Invalid CVE Param")
}
results = driver.GetExploitByCveID(param)
results, err = driver.GetExploitByCveID(param)
if err != nil {
return err
}
case "ID":
if !exploitDBIDRegexp.MatchString(param) {
log15.Error("Specify the search type [ID] parameters like `--param 10000`")
return errors.New("Invalid ID Param")
}
results = driver.GetExploitByID(param)
results, err = driver.GetExploitByID(param)
if err != nil {
return err
}
default:
log15.Error("Specify the search type [ CVE / ID].")
return errors.New("Invalid Type")
Expand Down
9 changes: 5 additions & 4 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ type DB interface {
OpenDB(dbType, dbPath string, debugSQL bool) (bool, error)
CloseDB() error
MigrateDB() error
GetExploitByID(string) []models.Exploit
GetExploitByCveID(string) []models.Exploit
GetExploitMultiByCveID([]string) map[string][]models.Exploit
GetExploitByID(string) ([]models.Exploit, error)
GetExploitMultiByID([]string) (map[string][]models.Exploit, error)
GetExploitByCveID(string) ([]models.Exploit, error)
GetExploitMultiByCveID([]string) (map[string][]models.Exploit, error)
InsertExploit(models.ExploitType, []models.Exploit) error
GetExploitAll() []models.Exploit
GetExploitAll() ([]models.Exploit, error)

IsExploitModelV1() (bool, error)
GetFetchMeta() (*models.FetchMeta, error)
Expand Down
116 changes: 72 additions & 44 deletions db/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/spf13/viper"
"github.com/vulsio/go-exploitdb/config"
"github.com/vulsio/go-exploitdb/models"
"github.com/vulsio/go-exploitdb/util"
"golang.org/x/xerrors"

"gorm.io/driver/mysql"
Expand Down Expand Up @@ -149,12 +148,11 @@ func (r *RDBDriver) deleteAndInsertExploit(exploitType models.ExploitType, explo
}

oldIDs := []int64{}
result := tx.Model(&models.Exploit{}).Select("id").Where("exploit_type = ?", exploitType).Find(&oldIDs)
if result.Error != nil && !errors.Is(result.Error, gorm.ErrRecordNotFound) {
return xerrors.Errorf("Failed to select old defs: %w", result.Error)
if err := tx.Model(&models.Exploit{}).Select("id").Where("exploit_type = ?", exploitType).Find(&oldIDs).Error; err != nil {
return xerrors.Errorf("Failed to select old defs: %w", err)
}

if result.RowsAffected > 0 {
if len(oldIDs) > 0 {
log15.Info("Deleting old Exploits")
bar := pb.StartNew(len(oldIDs))
for idx := range chunkSlice(len(oldIDs), batchSize) {
Expand Down Expand Up @@ -210,109 +208,139 @@ func (r *RDBDriver) deleteAndInsertExploit(exploitType models.ExploitType, explo
}

// GetExploitByID :
func (r *RDBDriver) GetExploitByID(exploitUniqueID string) []models.Exploit {
func (r *RDBDriver) GetExploitByID(exploitUniqueID string) ([]models.Exploit, error) {
es := []models.Exploit{}
var errs util.Errors
errs = errs.Add(r.conn.Where(&models.Exploit{ExploitUniqueID: exploitUniqueID}).Find(&es).Error)
if err := r.conn.Where(&models.Exploit{ExploitUniqueID: exploitUniqueID}).Find(&es).Error; err != nil {
log15.Error("Failed to get exploit by ExploitDB-ID", "err", err)
return nil, err
}
for i := range es {
switch es[i].ExploitType {
case models.OffensiveSecurityType:
errs = errs.Add(r.conn.
if err := r.conn.
Preload(clause.Associations).
Where(&models.OffensiveSecurity{ExploitID: es[i].ID}).
Take(&es[i].OffensiveSecurity).Error)
Take(&es[i].OffensiveSecurity).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get OffensiveSecurity. DB relationship may be broken, use `$ go-exploitdb fetch exploitdb` to recreate DB. err: %w", err)
}
log15.Error("Failed to get OffensiveSecurity", "err", err)
return nil, err
}
case models.GitHubRepositoryType:
errs = errs.Add(r.conn.Where(&models.GitHubRepository{ExploitID: es[i].ID}).Take(&es[i].GitHubRepository).Error)
}
}
for _, e := range errs.GetErrors() {
if !errors.Is(e, gorm.ErrRecordNotFound) {
log15.Error("Failed to get exploit by ExploitDB-ID", "err", e)
if err := r.conn.Where(&models.GitHubRepository{ExploitID: es[i].ID}).Take(&es[i].GitHubRepository).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get GitHubRepository. DB relationship may be broken, use `$ go-exploitdb fetch githubrepos` to recreate DB. err: %w", err)
}
log15.Error("Failed to get GitHubRepository", "err", err)
return nil, err
}
}
}
return es
return es, nil
}

// GetExploitAll :
func (r *RDBDriver) GetExploitAll() []models.Exploit {
func (r *RDBDriver) GetExploitAll() ([]models.Exploit, error) {
es := []models.Exploit{}

rows, err := r.conn.Model(&models.Exploit{}).Rows()
if err != nil {
log15.Error("Failed to Rows", "err", err)
return nil
return nil, err
}
defer rows.Close()

for rows.Next() {
exploit := models.Exploit{}
if err := r.conn.ScanRows(rows, &exploit); err != nil {
log15.Error("Failed to ScanRows", "err", err)
return nil
return nil, err
}
switch exploit.ExploitType {
case models.OffensiveSecurityType:
if err := r.conn.
Preload(clause.Associations).
Where(&models.OffensiveSecurity{ExploitID: exploit.ID}).
Take(&exploit.OffensiveSecurity).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get OffensiveSecurity. DB relationship may be broken, use `$ go-exploitdb fetch exploitdb` to recreate DB. err: %w", err)
}
log15.Error("Failed to Get OffensiveSecurity", "err", err)
return nil
return nil, err
}
case models.GitHubRepositoryType:
if err := r.conn.Where(&models.GitHubRepository{ExploitID: exploit.ID}).Take(&exploit.GitHubRepository).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get GitHubRepository. DB relationship may be broken, use `$ go-exploitdb fetch githubrepos` to recreate DB. err: %w", err)
}
log15.Error("Failed to Get GitHubRepository", "err", err)
return nil
return nil, err
}
}
es = append(es, exploit)
}

return es
return es, nil
}

// GetExploitMultiByID :
func (r *RDBDriver) GetExploitMultiByID(exploitUniqueIDs []string) map[string][]models.Exploit {
func (r *RDBDriver) GetExploitMultiByID(exploitUniqueIDs []string) (map[string][]models.Exploit, error) {
exploits := map[string][]models.Exploit{}
for _, exploitUniqueID := range exploitUniqueIDs {
exploits[exploitUniqueID] = r.GetExploitByID(exploitUniqueID)
exploit, err := r.GetExploitByID(exploitUniqueID)
if err != nil {
return nil, err
}
exploits[exploitUniqueID] = exploit
}
return exploits
return exploits, nil
}

// GetExploitByCveID :
func (r *RDBDriver) GetExploitByCveID(cveID string) []models.Exploit {
func (r *RDBDriver) GetExploitByCveID(cveID string) ([]models.Exploit, error) {
es := []models.Exploit{}
var errs util.Errors

errs = errs.Add(r.conn.Where(&models.Exploit{CveID: cveID}).Find(&es).Error)
if err := r.conn.Where(&models.Exploit{CveID: cveID}).Find(&es).Error; err != nil {
log15.Error("Failed to get exploit by ExploitDB-ID", "err", err)
return nil, err
}
for i := range es {
switch es[i].ExploitType {
case models.OffensiveSecurityType:
errs = errs.Add(r.conn.
if err := r.conn.
Preload(clause.Associations).
Where(&models.OffensiveSecurity{ExploitID: es[i].ID}).
Take(&es[i].OffensiveSecurity).Error)
Take(&es[i].OffensiveSecurity).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get OffensiveSecurity. DB relationship may be broken, use `$ go-exploitdb fetch exploitdb` to recreate DB. err: %w", err)
}
log15.Error("Failed to get OffensiveSecurity", "err", err)
return nil, err
}
case models.GitHubRepositoryType:
errs = errs.Add(r.conn.Where(&models.GitHubRepository{ExploitID: es[i].ID}).Take(&es[i].GitHubRepository).Error)
}
}

for _, e := range errs.GetErrors() {
if !errors.Is(e, gorm.ErrRecordNotFound) {
log15.Error("Failed to get exploit by CveID", "err", e)
if err := r.conn.Where(&models.GitHubRepository{ExploitID: es[i].ID}).Take(&es[i].GitHubRepository).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, xerrors.Errorf("Failed to get GitHubRepository. DB relationship may be broken, use `$ go-exploitdb fetch githubrepos` to recreate DB. err: %w", err)
}
log15.Error("Failed to get GitHubRepository", "err", err)
return nil, err
}
}
}
return es
return es, nil
}

// GetExploitMultiByCveID :
func (r *RDBDriver) GetExploitMultiByCveID(cveIDs []string) (exploits map[string][]models.Exploit) {
exploits = map[string][]models.Exploit{}
func (r *RDBDriver) GetExploitMultiByCveID(cveIDs []string) (map[string][]models.Exploit, error) {
exploits := map[string][]models.Exploit{}
for _, cveID := range cveIDs {
exploits[cveID] = r.GetExploitByCveID(cveID)
exploit, err := r.GetExploitByCveID(cveID)
if err != nil {
return nil, err
}
exploits[cveID] = exploit
}
return exploits
return exploits, nil
}

// IsExploitModelV1 determines if the DB was created at the time of go-exploitdb Model v1
Expand Down
Loading

0 comments on commit 10b78a3

Please sign in to comment.