Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more ptr utils + errkit improvements #573

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func ExpandWithEnv(variables ...*string) {

// EnvType is a type that can be used as a type for environment variables.
type EnvType interface {
~string | ~int | ~bool | ~float64 | time.Duration
~string | ~int | ~bool | ~float64 | time.Duration | ~rune
}

// GetEnvOrDefault returns the value of the environment variable or the default value if the variable is not set.
Expand Down
199 changes: 153 additions & 46 deletions errkit/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"errors"
"fmt"
"log/slog"
"runtime"
"strconv"
"strings"
"time"

"github.com/projectdiscovery/utils/env"
"golang.org/x/exp/maps"
)

const (
Expand All @@ -24,38 +26,78 @@ const (
DelimMultiLine = "\n - "
// MultiLinePrefix is the prefix used for multiline errors
MultiLineErrPrefix = "the following errors occurred:"
// Space is the identifier used for indentation
Space = " "
)

var (
// MaxErrorDepth is the maximum depth of errors to be unwrapped or maintained
// all errors beyond this depth will be ignored
MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3)
// ErrorSeperator is the seperator used to join errors
ErrorSeperator = env.GetEnvOrDefault("ERROR_SEPERATOR", "; ")
// FieldSeperator
ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Space)
// ErrChainSeperator
ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon)
// EnableTimestamp controls whether error timestamps are included
EnableTimestamp = env.GetEnvOrDefault("ENABLE_ERR_TIMESTAMP", false)
// EnableTrace controls whether error stack traces are included
EnableTrace = env.GetEnvOrDefault("ENABLE_ERR_TRACE", false)
)

// ErrorX is a custom error type that can handle all known types of errors
// wrapping and joining strategies including custom ones and it supports error class
// which can be shown to client/users in more meaningful way
type ErrorX struct {
kind ErrKind
attrs map[string]slog.Attr
errs []error
uniqErrs map[string]struct{}
kind ErrKind
record *slog.Record
source *slog.Source
errs []error
}

func (e *ErrorX) init(skipStack ...int) {
// initializes if necessary
if e.record == nil {
e.record = &slog.Record{}
if EnableTimestamp {
e.record.Time = time.Now()
}
if EnableTrace {
// get fn name
var pcs [1]uintptr
// skip [runtime.Callers, ErrorX.init, parent]
skip := 3
if len(skipStack) > 0 {
skip = skipStack[0]
}
runtime.Callers(skip, pcs[:])
pc := pcs[0]
fs := runtime.CallersFrames([]uintptr{pc})
f, _ := fs.Next()
e.source = &slog.Source{
Function: f.Function,
File: f.File,
Line: f.Line,
}
}
}
}

// append is internal method to append given
// error to error slice , it removes duplicates
// earlier it used map which causes more allocations that necessary
func (e *ErrorX) append(errs ...error) {
if e.uniqErrs == nil {
e.uniqErrs = make(map[string]struct{})
}
for _, err := range errs {
if _, ok := e.uniqErrs[err.Error()]; ok {
continue
for _, nerr := range errs {
found := false
new:
for _, oerr := range e.errs {
if oerr.Error() == nerr.Error() {
found = true
break new
}
}
if !found {
e.errs = append(e.errs, nerr)
}
e.uniqErrs[err.Error()] = struct{}{}
e.errs = append(e.errs, err)
}
}

Expand All @@ -71,8 +113,11 @@ func (e ErrorX) MarshalJSON() ([]byte, error) {
"kind": e.kind.String(),
"errors": tmp,
}
if len(e.attrs) > 0 {
m["attrs"] = slog.GroupValue(maps.Values(e.attrs)...)
if e.record != nil && e.record.NumAttrs() > 0 {
m["attrs"] = slog.GroupValue(e.Attrs()...)
}
if e.source != nil {
m["source"] = e.source
}
return json.Marshal(m)
}
Expand All @@ -84,10 +129,15 @@ func (e *ErrorX) Errors() []error {

// Attrs returns all attributes associated with the error
func (e *ErrorX) Attrs() []slog.Attr {
if e.attrs == nil {
if e.record == nil || e.record.NumAttrs() == 0 {
return nil
}
return maps.Values(e.attrs)
values := []slog.Attr{}
e.record.Attrs(func(a slog.Attr) bool {
values = append(values, a)
return true
})
return values
}

// Build returns the object as error interface
Expand All @@ -103,6 +153,7 @@ func (e *ErrorX) Unwrap() []error {
// Is checks if current error contains given error
func (e *ErrorX) Is(err error) bool {
x := &ErrorX{}
x.init()
parseError(x, err)
// even one submatch is enough
for _, orig := range e.errs {
Expand All @@ -118,20 +169,26 @@ func (e *ErrorX) Is(err error) bool {
// Error returns the error string
func (e *ErrorX) Error() string {
var sb strings.Builder
if e.kind != nil && e.kind.String() != "" {
sb.WriteString("errKind=")
sb.WriteString(e.kind.String())
sb.WriteString(" ")
}
if len(e.attrs) > 0 {
sb.WriteString(slog.GroupValue(maps.Values(e.attrs)...).String())
sb.WriteString(" ")
sb.WriteString("cause=")
sb.WriteString(strconv.Quote(e.errs[0].Error()))
if e.record != nil && e.record.NumAttrs() > 0 {
values := []string{}
e.record.Attrs(func(a slog.Attr) bool {
values = append(values, a.String())
return true
})
sb.WriteString(Space)
sb.WriteString(strings.Join(values, " "))
}
for _, err := range e.errs {
sb.WriteString(err.Error())
sb.WriteString(ErrorSeperator)
if len(e.errs) > 1 {
chain := []string{}
for _, value := range e.errs[1:] {
chain = append(chain, strings.TrimSpace(value.Error()))
}
sb.WriteString(Space)
sb.WriteString("chain=" + strconv.Quote(strings.Join(chain, ErrChainSeperator)))
}
return strings.TrimSuffix(sb.String(), ErrorSeperator)
return sb.String()
}

// Cause return the original error that caused this without any wrapping
Expand All @@ -158,28 +215,65 @@ func FromError(err error) *ErrorX {
return nil
}
nucleiErr := &ErrorX{}
nucleiErr.init()
parseError(nucleiErr, err)
return nucleiErr
}

// New creates a new error with the given message
func New(format string, args ...interface{}) *ErrorX {
// it follows slog pattern of adding and expects in the same way
//
// Example:
//
// this is correct (√)
// errkit.New("this is a nuclei error","address",host)
//
// this is not readable/recommended (x)
// errkit.New("this is a nuclei error",slog.String("address",host))
//
// this is wrong (x)
// errkit.New("this is a nuclei error %s",host)
func New(msg string, args ...interface{}) *ErrorX {
e := &ErrorX{}
e.append(fmt.Errorf(format, args...))
e.init()
if len(args) > 0 {
e.record.Add(args...)
}
e.append(errors.New(msg))
return e
}

// Msgf adds a message to the error
// it follows slog pattern of adding and expects in the same way
//
// Example:
//
// this is correct (√)
// myError.Msgf("dial error","network","tcp")
//
// this is not readable/recommended (x)
// myError.Msgf(slog.String("address",host))
//
// this is wrong (x)
// myError.Msgf("this is a nuclei error %s",host)
func (e *ErrorX) Msgf(format string, args ...interface{}) {
if e == nil {
return
}
if len(args) == 0 {
e.append(errors.New(format))
}
e.append(fmt.Errorf(format, args...))
}

// SetClass sets the class of the error
// if underlying error class was already set, then it is given preference
// when generating final error msg
//
// Example:
//
// this is correct (√)
// myError.SetKind(errkit.ErrKindNetworkPermanent)
func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
if e.kind == nil {
e.kind = kind
Expand All @@ -189,23 +283,30 @@ func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
return e
}

// ResetKind resets the error class of the error
//
// Example:
//
// myError.ResetKind()
func (e *ErrorX) ResetKind() *ErrorX {
e.kind = nil
return e
}

// Deprecated: use Attrs instead
//
// SetAttr sets additional attributes to a given error
// it only adds unique attributes and ignores duplicates
// Note: only key is checked for uniqueness
//
// Example:
//
// this is correct (√)
// myError.SetAttr(slog.String("address",host))
func (e *ErrorX) SetAttr(s ...slog.Attr) *ErrorX {
e.init()
for _, attr := range s {
if e.attrs == nil {
e.attrs = make(map[string]slog.Attr)
}
// check if this exists
if _, ok := e.attrs[attr.Key]; !ok && len(e.attrs) < MaxErrorDepth {
e.attrs[attr.Key] = attr
}
e.record.Add(attr)
}
return e
}
Expand All @@ -217,6 +318,7 @@ func parseError(to *ErrorX, err error) {
}
if to == nil {
to = &ErrorX{}
to.init(4)
}
if len(to.errs) >= MaxErrorDepth {
return
Expand All @@ -225,6 +327,17 @@ func parseError(to *ErrorX, err error) {
switch v := err.(type) {
case *ErrorX:
to.append(v.errs...)
if to.record == nil {
to.record = v.record
} else {
v.record.Attrs(func(a slog.Attr) bool {
to.record.Add(a)
return true
})
}
if to.source == nil {
to.source = v.source
}
to.kind = CombineErrKinds(to.kind, v.kind)
case JoinedError:
foundAny := false
Expand Down Expand Up @@ -283,9 +396,3 @@ func parseError(to *ErrorX, err error) {
}
}
}

// WrappedError is implemented by errors that are wrapped
type WrappedError interface {
// Unwrap returns the underlying error
Unwrap() error
}
12 changes: 12 additions & 0 deletions errkit/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,15 @@ func TestMarshalError(t *testing.T) {
require.NoError(t, err, "expected to be able to marshal the error")
require.Equal(t, `{"errors":["port closed or filtered","this is a wrapped error"],"kind":"network-permanent-error"}`, string(marshalled))
}

func TestErrorString(t *testing.T) {
var x error = New("i/o timeout")
x = With(x, "ip", "10.0.0.1", "port", 80)
x = WithMessage(x, "tcp dial error")
x = Append(x, errors.New("some other error"))

require.Equal(t,
`cause="i/o timeout" ip=10.0.0.1 port=80 chain="tcp dial error; some other error"`,
x.Error(),
)
}
14 changes: 8 additions & 6 deletions errkit/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,21 @@ func IsNetworkPermanentErr(err error) bool {
return isNetworkPermanentErr(x)
}

// WithAttr wraps error with given attributes
// With adds extra attributes to the error
//
// err = errkit.WithAttr(err,slog.Any("resource",domain))
func WithAttr(err error, attrs ...slog.Attr) error {
// err = errkit.With(err,"resource",domain)
func With(err error, args ...any) error {
if err == nil {
return nil
}
if len(attrs) == 0 {
if len(args) == 0 {
return err
}
x := &ErrorX{}
x.init()
parseError(x, err)
return x.SetAttr(attrs...)
x.record.Add(args...)
return x
}

// GetAttr returns all attributes of given error if it has any
Expand Down Expand Up @@ -271,7 +273,7 @@ func GetAttrValue(err error, key string) slog.Value {
}
x := &ErrorX{}
parseError(x, err)
for _, attr := range x.attrs {
for _, attr := range x.Attrs() {
if attr.Key == key {
return attr.Value
}
Expand Down
6 changes: 6 additions & 0 deletions errkit/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,9 @@ type ComparableError interface {
// Is checks if current error contains given error
Is(err error) bool
}

// WrappedError is implemented by errors that are wrapped
type WrappedError interface {
// Unwrap returns the underlying error
Unwrap() error
}
Loading
Loading