-
Notifications
You must be signed in to change notification settings - Fork 589
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into mattverse/disable-stableswap
- Loading branch information
Showing
6 changed files
with
513 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package osmoutils | ||
|
||
import ( | ||
"errors" | ||
|
||
sdk "github.com/cosmos/cosmos-sdk/types" | ||
) | ||
|
||
// ErrTolerance is used to define a compare function, which checks if two | ||
// ints are within a certain error tolerance of one another. | ||
// ErrTolerance.Compare(a, b) returns true iff: | ||
// |a - b| <= AdditiveTolerance | ||
// |a - b| / min(a, b) <= MultiplicativeTolerance | ||
// Each check is respectively ignored if the entry is nil (sdk.Dec{}, sdk.Int{}) | ||
// Note that if AdditiveTolerance == 0, then this is equivalent to a standard compare. | ||
type ErrTolerance struct { | ||
AdditiveTolerance sdk.Int | ||
MultiplicativeTolerance sdk.Dec | ||
} | ||
|
||
// Compare returns if actual is within errTolerance of expected. | ||
// returns 0 if it is | ||
// returns 1 if not, and expected > actual. | ||
// returns -1 if not, and expected < actual | ||
func (e ErrTolerance) Compare(expected sdk.Int, actual sdk.Int) int { | ||
diff := expected.Sub(actual).Abs() | ||
|
||
comparisonSign := 0 | ||
if expected.GT(actual) { | ||
comparisonSign = 1 | ||
} else { | ||
comparisonSign = -1 | ||
} | ||
|
||
// if no error accepted, do a direct compare. | ||
if e.AdditiveTolerance.IsZero() { | ||
if expected.Equal(actual) { | ||
return 0 | ||
} else { | ||
return comparisonSign | ||
} | ||
} | ||
|
||
// Check additive tolerance equations | ||
if !e.AdditiveTolerance.IsNil() && !e.AdditiveTolerance.IsZero() { | ||
if diff.GT(e.AdditiveTolerance) { | ||
return comparisonSign | ||
} | ||
} | ||
// Check multiplicative tolerance equations | ||
if !e.MultiplicativeTolerance.IsNil() && !e.MultiplicativeTolerance.IsZero() { | ||
errTerm := diff.ToDec().Quo(sdk.MinInt(expected, actual).ToDec()) | ||
if errTerm.GT(e.MultiplicativeTolerance) { | ||
return comparisonSign | ||
} | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
// Binary search inputs between [lowerbound, upperbound] to a monotonic increasing function f. | ||
// We stop once f(found_input) meets the ErrTolerance constraints. | ||
// If we perform more than maxIterations (or equivalently lowerbound = upperbound), we return an error. | ||
func BinarySearch(f func(input sdk.Int) (sdk.Int, error), | ||
lowerbound sdk.Int, | ||
upperbound sdk.Int, | ||
targetOutput sdk.Int, | ||
errTolerance ErrTolerance, | ||
maxIterations int) (sdk.Int, error) { | ||
// Setup base case of loop | ||
curEstimate := lowerbound.Add(upperbound).QuoRaw(2) | ||
curOutput, err := f(curEstimate) | ||
if err != nil { | ||
return sdk.Int{}, err | ||
} | ||
curIteration := 0 | ||
for ; curIteration < maxIterations; curIteration += 1 { | ||
compRes := errTolerance.Compare(curOutput, targetOutput) | ||
if compRes > 0 { | ||
upperbound = curEstimate | ||
} else if compRes < 0 { | ||
lowerbound = curEstimate | ||
} else { | ||
break | ||
} | ||
curEstimate = lowerbound.Add(upperbound).QuoRaw(2) | ||
curOutput, err = f(curEstimate) | ||
if err != nil { | ||
return sdk.Int{}, err | ||
} | ||
} | ||
if curIteration == maxIterations { | ||
return sdk.Int{}, errors.New("hit maximum iterations, did not converge fast enough") | ||
} | ||
return curEstimate, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package osmoutils | ||
|
||
import ( | ||
"testing" | ||
|
||
sdk "github.com/cosmos/cosmos-sdk/types" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestBinarySearch(t *testing.T) { | ||
// straight line function that returns input. Simplest to binary search on, | ||
// binary search directly reveals one bit of the answer in each iteration with this function. | ||
lineF := func(a sdk.Int) (sdk.Int, error) { | ||
return a, nil | ||
} | ||
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} | ||
tests := []struct { | ||
f func(sdk.Int) (sdk.Int, error) | ||
lowerbound sdk.Int | ||
upperbound sdk.Int | ||
targetOutput sdk.Int | ||
errTolerance ErrTolerance | ||
maxIterations int | ||
|
||
expectedSolvedInput sdk.Int | ||
expectErr bool | ||
}{ | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, | ||
} | ||
|
||
for _, tc := range tests { | ||
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) | ||
if tc.expectErr { | ||
require.Error(t, err) | ||
} else { | ||
require.NoError(t, err) | ||
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) | ||
} | ||
} | ||
} | ||
|
||
func TestBinarySearchNonlinear(t *testing.T) { | ||
// straight line function that returns input. Simplest to binary search on, | ||
// binary search directly reveals one bit of the answer in each iteration with this function. | ||
lineF := func(a sdk.Int) (sdk.Int, error) { | ||
return a, nil | ||
} | ||
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} | ||
tests := []struct { | ||
f func(sdk.Int) (sdk.Int, error) | ||
lowerbound sdk.Int | ||
upperbound sdk.Int | ||
targetOutput sdk.Int | ||
errTolerance ErrTolerance | ||
maxIterations int | ||
|
||
expectedSolvedInput sdk.Int | ||
expectErr bool | ||
}{ | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, | ||
} | ||
|
||
for _, tc := range tests { | ||
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) | ||
if tc.expectErr { | ||
require.Error(t, err) | ||
} else { | ||
require.NoError(t, err) | ||
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) | ||
} | ||
} | ||
} | ||
|
||
func TestBinarySearchNonlinearNonzero(t *testing.T) { | ||
// non-linear function that returns input. Simplest to binary search on, | ||
// binary search directly reveals one bit of the answer in each iteration with this function. | ||
lineF := func(a sdk.Int) (sdk.Int, error) { | ||
return a, nil | ||
} | ||
noErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt()} | ||
tests := []struct { | ||
f func(sdk.Int) (sdk.Int, error) | ||
lowerbound sdk.Int | ||
upperbound sdk.Int | ||
targetOutput sdk.Int | ||
errTolerance ErrTolerance | ||
maxIterations int | ||
|
||
expectedSolvedInput sdk.Int | ||
expectErr bool | ||
}{ | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 51, sdk.NewInt(1 + (1 << 25)), false}, | ||
{lineF, sdk.ZeroInt(), sdk.NewInt(1 << 50), sdk.NewInt(1 + (1 << 25)), noErrTolerance, 10, sdk.Int{}, true}, | ||
} | ||
|
||
for _, tc := range tests { | ||
actualSolvedInput, err := BinarySearch(tc.f, tc.lowerbound, tc.upperbound, tc.targetOutput, tc.errTolerance, tc.maxIterations) | ||
if tc.expectErr { | ||
require.Error(t, err) | ||
} else { | ||
require.NoError(t, err) | ||
require.True(sdk.IntEq(t, tc.expectedSolvedInput, actualSolvedInput)) | ||
} | ||
} | ||
} | ||
|
||
func TestErrTolerance_Compare(t *testing.T) { | ||
ZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.ZeroInt(), MultiplicativeTolerance: sdk.Dec{}} | ||
tests := []struct { | ||
name string | ||
tol ErrTolerance | ||
input sdk.Int | ||
reference sdk.Int | ||
|
||
expectedCompareResult int | ||
}{ | ||
{"0 tolerance: <", ZeroErrTolerance, sdk.NewInt(1000), sdk.NewInt(1001), -1}, | ||
{"0 tolerance: =", ZeroErrTolerance, sdk.NewInt(1001), sdk.NewInt(1001), 0}, | ||
{"0 tolerance: >", ZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 1}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult { | ||
t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestErrToleranceNonzero_Compare(t *testing.T) { | ||
// Nonzero error tolerance test | ||
NonZeroErrTolerance := ErrTolerance{AdditiveTolerance: sdk.NewInt(10), MultiplicativeTolerance: sdk.Dec{}} | ||
tests := []struct { | ||
name string | ||
tol ErrTolerance | ||
input sdk.Int | ||
reference sdk.Int | ||
|
||
expectedCompareResult int | ||
}{ | ||
{"Nonzero tolerance: <", NonZeroErrTolerance, sdk.NewInt(420), sdk.NewInt(1001), -1}, | ||
{"Nonzero tolerance: =", NonZeroErrTolerance, sdk.NewInt(1002), sdk.NewInt(1001), 0}, | ||
{"Nonzero tolerance: >", NonZeroErrTolerance, sdk.NewInt(1230), sdk.NewInt(1001), 1}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
if got := tt.tol.Compare(tt.input, tt.reference); got != tt.expectedCompareResult { | ||
t.Errorf("ErrTolerance.Compare() = %v, want %v", got, tt.expectedCompareResult) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.