diff --git a/slices/cmp.go b/slices/cmp.go new file mode 100644 index 000000000..fbf1934a0 --- /dev/null +++ b/slices/cmp.go @@ -0,0 +1,44 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import "golang.org/x/exp/constraints" + +// min is a version of the predeclared function from the Go 1.21 release. +func min[T constraints.Ordered](a, b T) T { + if a < b || isNaN(a) { + return a + } + return b +} + +// max is a version of the predeclared function from the Go 1.21 release. +func max[T constraints.Ordered](a, b T) T { + if a > b || isNaN(a) { + return a + } + return b +} + +// cmpLess is a copy of cmp.Less from the Go 1.21 release. +func cmpLess[T constraints.Ordered](x, y T) bool { + return (isNaN(x) && !isNaN(y)) || x < y +} + +// cmpCompare is a copy of cmp.Compare from the Go 1.21 release. +func cmpCompare[T constraints.Ordered](x, y T) int { + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} diff --git a/slices/slices.go b/slices/slices.go index 8a7cf20db..5e8158bba 100644 --- a/slices/slices.go +++ b/slices/slices.go @@ -3,23 +3,20 @@ // license that can be found in the LICENSE file. // Package slices defines various functions useful with slices of any type. -// Unless otherwise specified, these functions all apply to the elements -// of a slice at index 0 <= i < len(s). -// -// Note that the less function in IsSortedFunc, SortFunc, SortStableFunc requires a -// strict weak ordering (https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings), -// or the sorting may fail to sort correctly. A common case is when sorting slices of -// floating-point numbers containing NaN values. package slices -import "golang.org/x/exp/constraints" +import ( + "unsafe" + + "golang.org/x/exp/constraints" +) // Equal reports whether two slices are equal: the same length and all // elements equal. If the lengths are different, Equal returns false. // Otherwise, the elements are compared in increasing index order, and the // comparison stops at the first unequal pair. // Floating point NaNs are not considered equal. -func Equal[E comparable](s1, s2 []E) bool { +func Equal[S ~[]E, E comparable](s1, s2 S) bool { if len(s1) != len(s2) { return false } @@ -31,12 +28,12 @@ func Equal[E comparable](s1, s2 []E) bool { return true } -// EqualFunc reports whether two slices are equal using a comparison +// EqualFunc reports whether two slices are equal using an equality // function on each pair of elements. If the lengths are different, // EqualFunc returns false. Otherwise, the elements are compared in // increasing index order, and the comparison stops at the first index // for which eq returns false. -func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { +func EqualFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, eq func(E1, E2) bool) bool { if len(s1) != len(s2) { return false } @@ -49,45 +46,37 @@ func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { return true } -// Compare compares the elements of s1 and s2. -// The elements are compared sequentially, starting at index 0, +// Compare compares the elements of s1 and s2, using [cmp.Compare] on each pair +// of elements. The elements are compared sequentially, starting at index 0, // until one element is not equal to the other. // The result of comparing the first non-matching elements is returned. // If both slices are equal until one of them ends, the shorter slice is // considered less than the longer one. // The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2. -// Comparisons involving floating point NaNs are ignored. -func Compare[E constraints.Ordered](s1, s2 []E) int { - s2len := len(s2) +func Compare[S ~[]E, E constraints.Ordered](s1, s2 S) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] - switch { - case v1 < v2: - return -1 - case v1 > v2: - return +1 + if c := cmpCompare(v1, v2); c != 0 { + return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 } -// CompareFunc is like Compare but uses a comparison function -// on each pair of elements. The elements are compared in increasing -// index order, and the comparisons stop after the first time cmp -// returns non-zero. +// CompareFunc is like [Compare] but uses a custom comparison function on each +// pair of elements. // The result is the first non-zero result of cmp; if cmp always // returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2), // and +1 if len(s1) > len(s2). -func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { - s2len := len(s2) +func CompareFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, cmp func(E1, E2) int) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] @@ -95,7 +84,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 @@ -103,7 +92,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { // Index returns the index of the first occurrence of v in s, // or -1 if not present. -func Index[E comparable](s []E, v E) int { +func Index[S ~[]E, E comparable](s S, v E) int { for i := range s { if v == s[i] { return i @@ -114,7 +103,7 @@ func Index[E comparable](s []E, v E) int { // IndexFunc returns the first index i satisfying f(s[i]), // or -1 if none do. -func IndexFunc[E any](s []E, f func(E) bool) int { +func IndexFunc[S ~[]E, E any](s S, f func(E) bool) int { for i := range s { if f(s[i]) { return i @@ -124,39 +113,104 @@ func IndexFunc[E any](s []E, f func(E) bool) int { } // Contains reports whether v is present in s. -func Contains[E comparable](s []E, v E) bool { +func Contains[S ~[]E, E comparable](s S, v E) bool { return Index(s, v) >= 0 } // ContainsFunc reports whether at least one // element e of s satisfies f(e). -func ContainsFunc[E any](s []E, f func(E) bool) bool { +func ContainsFunc[S ~[]E, E any](s S, f func(E) bool) bool { return IndexFunc(s, f) >= 0 } // Insert inserts the values v... into s at index i, // returning the modified slice. -// In the returned slice r, r[i] == v[0]. +// The elements at s[i:] are shifted up to make room. +// In the returned slice r, r[i] == v[0], +// and r[i+len(v)] == value originally at r[i]. // Insert panics if i is out of range. // This function is O(len(s) + len(v)). func Insert[S ~[]E, E any](s S, i int, v ...E) S { - tot := len(s) + len(v) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[i:]) + m := len(v) + if m == 0 { + return s + } + n := len(s) + if i == n { + return append(s, v...) + } + if n+m > cap(s) { + // Use append rather than make so that we bump the size of + // the slice up to the next storage class. + // This is what Grow does but we don't call Grow because + // that might copy the values twice. + s2 := append(s[:i], make(S, n+m-i)...) copy(s2[i:], v) + copy(s2[i+m:], s[i:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[i:]) - return s2 + s = s[:n+m] + + // before: + // s: aaaaaaaabbbbccccccccdddd + // ^ ^ ^ ^ + // i i+m n n+m + // after: + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // + // a are the values that don't move in s. + // v are the values copied in from v. + // b and c are the values from s that are shifted up in index. + // d are the values that get overwritten, never to be seen again. + + if !overlaps(v, s[i+m:]) { + // Easy case - v does not overlap either the c or d regions. + // (It might be in some of a or b, or elsewhere entirely.) + // The data we copy up doesn't write to v at all, so just do it. + + copy(s[i+m:], s[i:]) + + // Now we have + // s: aaaaaaaabbbbbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // Note the b values are duplicated. + + copy(s[i:], v) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s + } + + // The hard case - v overlaps c or d. We can't just shift up + // the data because we'd move or clobber the values we're trying + // to insert. + // So instead, write v on top of d, then rotate. + copy(s[n:], v) + + // Now we have + // s: aaaaaaaabbbbccccccccvvvv + // ^ ^ ^ ^ + // i i+m n n+m + + rotateRight(s[i:], m) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s } // Delete removes the elements s[i:j] from s, returning the modified slice. // Delete panics if s[i:j] is not a valid slice of s. -// Delete modifies the contents of the slice s; it does not create a new slice. // Delete is O(len(s)-j), so if many items must be deleted, it is better to // make a single call deleting them all together than to delete one at a time. // Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those @@ -175,39 +229,106 @@ func Delete[S ~[]E, E any](s S, i, j int) S { // zeroing those elements so that objects they reference can be garbage // collected. func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := IndexFunc(s, del) + if i == -1 { + return s + } // Don't start copying elements until we find one to delete. - for i, v := range s { - if del(v) { - j := i - for i++; i < len(s); i++ { - v = s[i] - if !del(v) { - s[j] = v - j++ - } - } - return s[:j] + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ } } - return s + return s[:i] } // Replace replaces the elements s[i:j] by the given v, and returns the // modified slice. Replace panics if s[i:j] is not a valid slice of s. func Replace[S ~[]E, E any](s S, i, j int, v ...E) S { _ = s[i:j] // verify that i:j is a valid subslice + + if i == j { + return Insert(s, i, v...) + } + if j == len(s) { + return append(s[:i], v...) + } + tot := len(s[:i]) + len(v) + len(s[j:]) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[j:]) + if tot > cap(s) { + // Too big to fit, allocate and copy over. + s2 := append(s[:i], make(S, tot-i)...) // See Insert copy(s2[i:], v) + copy(s2[i+len(v):], s[j:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[j:]) - return s2 + + r := s[:tot] + + if i+len(v) <= j { + // Easy, as v fits in the deleted portion. + copy(r[i:], v) + if i+len(v) != j { + copy(r[i+len(v):], s[j:]) + } + return r + } + + // We are expanding (v is bigger than j-i). + // The situation is something like this: + // (example has i=4,j=8,len(s)=16,len(v)=6) + // s: aaaaxxxxbbbbbbbbyy + // ^ ^ ^ ^ + // i j len(s) tot + // a: prefix of s + // x: deleted range + // b: more of s + // y: area to expand into + + if !overlaps(r[i+len(v):], v) { + // Easy, as v is not clobbered by the first copy. + copy(r[i+len(v):], s[j:]) + copy(r[i:], v) + return r + } + + // This is a situation where we don't have a single place to which + // we can copy v. Parts of it need to go to two different places. + // We want to copy the prefix of v into y and the suffix into x, then + // rotate |y| spots to the right. + // + // v[2:] v[:2] + // | | + // s: aaaavvvvbbbbbbbbvv + // ^ ^ ^ ^ + // i j len(s) tot + // + // If either of those two destinations don't alias v, then we're good. + y := len(v) - (j - i) // length of y portion + + if !overlaps(r[i:j], v) { + copy(r[i:j], v[y:]) + copy(r[len(s):], v[:y]) + rotateRight(r[i:], y) + return r + } + if !overlaps(r[len(s):], v) { + copy(r[len(s):], v[:y]) + copy(r[i:j], v[y:]) + rotateRight(r[i:], y) + return r + } + + // Now we know that v overlaps both x and y. + // That means that the entirety of b is *inside* v. + // So we don't need to preserve b at all; instead we + // can copy v first, then copy the b part of v out of + // v to the right destination. + k := startIdx(v, s[j:]) + copy(r[i:], v) + copy(r[i+len(v):], r[i+k:]) + return r } // Clone returns a copy of the slice. @@ -222,7 +343,8 @@ func Clone[S ~[]E, E any](s S) S { // Compact replaces consecutive runs of equal elements with a single copy. // This is like the uniq command found on Unix. -// Compact modifies the contents of the slice s; it does not create a new slice. +// Compact modifies the contents of the slice s and returns the modified slice, +// which may have a smaller length. // When Compact discards m elements in total, it might not modify the elements // s[len(s)-m:len(s)]. If those elements contain pointers you might consider // zeroing those elements so that objects they reference can be garbage collected. @@ -242,7 +364,8 @@ func Compact[S ~[]E, E comparable](s S) S { return s[:i] } -// CompactFunc is like Compact but uses a comparison function. +// CompactFunc is like [Compact] but uses an equality function to compare elements. +// For runs of elements that compare equal, CompactFunc keeps the first one. func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { if len(s) < 2 { return s @@ -280,3 +403,97 @@ func Grow[S ~[]E, E any](s S, n int) S { func Clip[S ~[]E, E any](s S) S { return s[:len(s):len(s)] } + +// Rotation algorithm explanation: +// +// rotate left by 2 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join first parts +// 89234567 01 +// recursively rotate first left part by 2 +// 23456789 01 +// join at the end +// 2345678901 +// +// rotate left by 8 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join last parts +// 89 23456701 +// recursively rotate second part left by 6 +// 89 01234567 +// join at the end +// 8901234567 + +// TODO: There are other rotate algorithms. +// This algorithm has the desirable property that it moves each element exactly twice. +// The triple-reverse algorithm is simpler and more cache friendly, but takes more writes. +// The follow-cycles algorithm can be 1-write but it is not very cache friendly. + +// rotateLeft rotates b left by n spaces. +// s_final[i] = s_orig[i+r], wrapping around. +func rotateLeft[E any](s []E, r int) { + for r != 0 && r != len(s) { + if r*2 <= len(s) { + swap(s[:r], s[len(s)-r:]) + s = s[:len(s)-r] + } else { + swap(s[:len(s)-r], s[r:]) + s, r = s[len(s)-r:], r*2-len(s) + } + } +} +func rotateRight[E any](s []E, r int) { + rotateLeft(s, len(s)-r) +} + +// swap swaps the contents of x and y. x and y must be equal length and disjoint. +func swap[E any](x, y []E) { + for i := 0; i < len(x); i++ { + x[i], y[i] = y[i], x[i] + } +} + +// overlaps reports whether the memory ranges a[0:len(a)] and b[0:len(b)] overlap. +func overlaps[E any](a, b []E) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + elemSize := unsafe.Sizeof(a[0]) + if elemSize == 0 { + return false + } + // TODO: use a runtime/unsafe facility once one becomes available. See issue 12445. + // Also see crypto/internal/alias/alias.go:AnyOverlap + return uintptr(unsafe.Pointer(&a[0])) <= uintptr(unsafe.Pointer(&b[len(b)-1]))+(elemSize-1) && + uintptr(unsafe.Pointer(&b[0])) <= uintptr(unsafe.Pointer(&a[len(a)-1]))+(elemSize-1) +} + +// startIdx returns the index in haystack where the needle starts. +// prerequisite: the needle must be aliased entirely inside the haystack. +func startIdx[E any](haystack, needle []E) int { + p := &needle[0] + for i := range haystack { + if p == &haystack[i] { + return i + } + } + // TODO: what if the overlap is by a non-integral number of Es? + panic("needle not found") +} + +// Reverse reverses the elements of the slice in place. +func Reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/slices/slices_test.go b/slices/slices_test.go index c2402dd76..7371bdb88 100644 --- a/slices/slices_test.go +++ b/slices/slices_test.go @@ -8,8 +8,6 @@ import ( "math" "strings" "testing" - - "golang.org/x/exp/constraints" ) var raceEnabled bool @@ -84,7 +82,7 @@ func equalNaN[T comparable](v1, v2 T) bool { } // offByOne returns true if integers v1 and v2 differ by 1. -func offByOne[E constraints.Integer](v1, v2 E) bool { +func offByOne(v1, v2 int) bool { return v1 == v2+1 || v1 == v2-1 } @@ -105,10 +103,10 @@ func TestEqualFunc(t *testing.T) { s1 := []int{1, 2, 3} s2 := []int{2, 3, 4} - if EqualFunc(s1, s1, offByOne[int]) { + if EqualFunc(s1, s1, offByOne) { t.Errorf("EqualFunc(%v, %v, offByOne) = true, want false", s1, s1) } - if !EqualFunc(s1, s2, offByOne[int]) { + if !EqualFunc(s1, s2, offByOne) { t.Errorf("EqualFunc(%v, %v, offByOne) = false, want true", s1, s2) } @@ -140,6 +138,31 @@ var compareIntTests = []struct { s1, s2 []int want int }{ + { + []int{1}, + []int{1}, + 0, + }, + { + []int{1}, + []int{}, + 1, + }, + { + []int{}, + []int{1}, + -1, + }, + { + []int{}, + []int{}, + 0, + }, + { + []int{1, 2, 3}, + []int{1, 2, 3}, + 0, + }, { []int{1, 2, 3}, []int{1, 2, 3, 4}, @@ -160,12 +183,32 @@ var compareIntTests = []struct { []int{1, 2, 3}, +1, }, + { + []int{1, 4, 3}, + []int{1, 2, 3, 8, 9}, + +1, + }, } var compareFloatTests = []struct { s1, s2 []float64 want int }{ + { + []float64{}, + []float64{}, + 0, + }, + { + []float64{1}, + []float64{1}, + 0, + }, + { + []float64{math.NaN()}, + []float64{math.NaN()}, + 0, + }, { []float64{1, 2, math.NaN()}, []float64{1, 2, math.NaN()}, @@ -184,13 +227,23 @@ var compareFloatTests = []struct { { []float64{1, math.NaN(), 3}, []float64{1, 2, math.NaN()}, - 0, + -1, }, { - []float64{1, math.NaN(), 3, 4}, + []float64{1, 2, 3}, []float64{1, 2, math.NaN()}, +1, }, + { + []float64{1, 2, 3}, + []float64{1, math.NaN(), 3}, + +1, + }, + { + []float64{1, math.NaN(), 3, 4}, + []float64{1, 2, math.NaN()}, + -1, + }, } func TestCompare(t *testing.T) { @@ -232,16 +285,6 @@ func equalToCmp[T comparable](eq func(T, T) bool) func(T, T) int { } } -func cmp[T constraints.Ordered](v1, v2 T) int { - if v1 < v2 { - return -1 - } else if v1 > v2 { - return 1 - } else { - return 0 - } -} - func TestCompareFunc(t *testing.T) { intWant := func(want bool) string { if want { @@ -261,19 +304,19 @@ func TestCompareFunc(t *testing.T) { } for _, test := range compareIntTests { - if got := CompareFunc(test.s1, test.s2, cmp[int]); got != test.want { + if got := CompareFunc(test.s1, test.s2, cmpCompare[int]); got != test.want { t.Errorf("CompareFunc(%v, %v, cmp[int]) = %d, want %d", test.s1, test.s2, got, test.want) } } for _, test := range compareFloatTests { - if got := CompareFunc(test.s1, test.s2, cmp[float64]); got != test.want { + if got := CompareFunc(test.s1, test.s2, cmpCompare[float64]); got != test.want { t.Errorf("CompareFunc(%v, %v, cmp[float64]) = %d, want %d", test.s1, test.s2, got, test.want) } } s1 := []int{1, 2, 3} s2 := []int{2, 3, 4} - if got := CompareFunc(s1, s2, equalToCmp(offByOne[int])); got != 0 { + if got := CompareFunc(s1, s2, equalToCmp(offByOne)); got != 0 { t.Errorf("CompareFunc(%v, %v, offByOne) = %d, want 0", s1, s2, got) } @@ -450,6 +493,45 @@ func TestInsert(t *testing.T) { t.Errorf("Insert(%v, %d, %v...) = %v, want %v", test.s, test.i, test.add, got, test.want) } } + + if !raceEnabled { + // Allocations should be amortized. + const count = 50 + n := testing.AllocsPerRun(10, func() { + s := []int{1, 2, 3} + for i := 0; i < count; i++ { + s = Insert(s, 0, 1) + } + }) + if n > count/2 { + t.Errorf("too many allocations inserting %d elements: got %v, want less than %d", count, n, count/2) + } + } +} + +func TestInsertOverlap(t *testing.T) { + const N = 10 + a := make([]int, N) + want := make([]int, 2*N) + for n := 0; n <= N; n++ { // length + for i := 0; i <= n; i++ { // insertion point + for x := 0; x <= N; x++ { // start of inserted data + for y := x; y <= N; y++ { // end of inserted data + for k := 0; k < N; k++ { + a[k] = k + } + want = want[:0] + want = append(want, a[:i]...) + want = append(want, a[x:y]...) + want = append(want, a[i:n]...) + got := Insert(a[:n], i, a[x:y]...) + if !Equal(got, want) { + t.Errorf("Insert with overlap failed n=%d i=%d x=%d y=%d, got %v want %v", n, i, x, y, got, want) + } + } + } + } + } } var deleteTests = []struct { @@ -748,6 +830,34 @@ func TestClip(t *testing.T) { } } +func TestReverse(t *testing.T) { + even := []int{3, 1, 4, 1, 5, 9} // len = 6 + Reverse(even) + if want := []int{9, 5, 1, 4, 1, 3}; !Equal(even, want) { + t.Errorf("Reverse(even) = %v, want %v", even, want) + } + + odd := []int{3, 1, 4, 1, 5, 9, 2} // len = 7 + Reverse(odd) + if want := []int{2, 9, 5, 1, 4, 1, 3}; !Equal(odd, want) { + t.Errorf("Reverse(odd) = %v, want %v", odd, want) + } + + words := strings.Fields("one two three") + Reverse(words) + if want := strings.Fields("three two one"); !Equal(words, want) { + t.Errorf("Reverse(words) = %v, want %v", words, want) + } + + singleton := []string{"one"} + Reverse(singleton) + if want := []string{"one"}; !Equal(singleton, want) { + t.Errorf("Reverse(singeleton) = %v, want %v", singleton, want) + } + + Reverse[[]string](nil) +} + // naiveReplace is a baseline implementation to the Replace function. func naiveReplace[S ~[]E, E any](s S, i, j int, v ...E) S { s = Delete(s, i, j) @@ -812,6 +922,33 @@ func TestReplacePanics(t *testing.T) { } } +func TestReplaceOverlap(t *testing.T) { + const N = 10 + a := make([]int, N) + want := make([]int, 2*N) + for n := 0; n <= N; n++ { // length + for i := 0; i <= n; i++ { // insertion point 1 + for j := i; j <= n; j++ { // insertion point 2 + for x := 0; x <= N; x++ { // start of inserted data + for y := x; y <= N; y++ { // end of inserted data + for k := 0; k < N; k++ { + a[k] = k + } + want = want[:0] + want = append(want, a[:i]...) + want = append(want, a[x:y]...) + want = append(want, a[j:n]...) + got := Replace(a[:n], i, j, a[x:y]...) + if !Equal(got, want) { + t.Errorf("Insert with overlap failed n=%d i=%d j=%d x=%d y=%d, got %v want %v", n, i, j, x, y, got, want) + } + } + } + } + } + } +} + func BenchmarkReplace(b *testing.B) { cases := []struct { name string @@ -860,3 +997,58 @@ func BenchmarkReplace(b *testing.B) { } } + +func TestRotate(t *testing.T) { + const N = 10 + s := make([]int, 0, N) + for n := 0; n < N; n++ { + for r := 0; r < n; r++ { + s = s[:0] + for i := 0; i < n; i++ { + s = append(s, i) + } + rotateLeft(s, r) + for i := 0; i < n; i++ { + if s[i] != (i+r)%n { + t.Errorf("expected n=%d r=%d i:%d want:%d got:%d", n, r, i, (i+r)%n, s[i]) + } + } + } + } +} + +func TestInsertGrowthRate(t *testing.T) { + b := make([]byte, 1) + maxCap := cap(b) + nGrow := 0 + const N = 1e6 + for i := 0; i < N; i++ { + b = Insert(b, len(b)-1, 0) + if cap(b) > maxCap { + maxCap = cap(b) + nGrow++ + } + } + want := int(math.Log(N) / math.Log(1.25)) // 1.25 == growth rate for large slices + if nGrow > want { + t.Errorf("too many grows. got:%d want:%d", nGrow, want) + } +} + +func TestReplaceGrowthRate(t *testing.T) { + b := make([]byte, 2) + maxCap := cap(b) + nGrow := 0 + const N = 1e6 + for i := 0; i < N; i++ { + b = Replace(b, len(b)-2, len(b)-1, 0, 0) + if cap(b) > maxCap { + maxCap = cap(b) + nGrow++ + } + } + want := int(math.Log(N) / math.Log(1.25)) // 1.25 == growth rate for large slices + if nGrow > want { + t.Errorf("too many grows. got:%d want:%d", nGrow, want) + } +} diff --git a/slices/sort.go b/slices/sort.go index 231b6448a..b67897f76 100644 --- a/slices/sort.go +++ b/slices/sort.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:generate go run $GOROOT/src/sort/gen_sort_variants.go -exp + package slices import ( @@ -11,57 +13,116 @@ import ( ) // Sort sorts a slice of any ordered type in ascending order. -// Sort may fail to sort correctly when sorting slices of floating-point -// numbers containing Not-a-number (NaN) values. -// Use slices.SortFunc(x, func(a, b float64) bool {return a < b || (math.IsNaN(a) && !math.IsNaN(b))}) -// instead if the input may contain NaNs. -func Sort[E constraints.Ordered](x []E) { +// When sorting floating-point numbers, NaNs are ordered before other values. +func Sort[S ~[]E, E constraints.Ordered](x S) { n := len(x) pdqsortOrdered(x, 0, n, bits.Len(uint(n))) } -// SortFunc sorts the slice x in ascending order as determined by the less function. -// This sort is not guaranteed to be stable. +// SortFunc sorts the slice x in ascending order as determined by the cmp +// function. This sort is not guaranteed to be stable. +// cmp(a, b) should return a negative number when a < b, a positive number when +// a > b and zero when a == b. // -// SortFunc requires that less is a strict weak ordering. +// SortFunc requires that cmp is a strict weak ordering. // See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings. -func SortFunc[E any](x []E, less func(a, b E) bool) { +func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { n := len(x) - pdqsortLessFunc(x, 0, n, bits.Len(uint(n)), less) + pdqsortCmpFunc(x, 0, n, bits.Len(uint(n)), cmp) } // SortStableFunc sorts the slice x while keeping the original order of equal -// elements, using less to compare elements. -func SortStableFunc[E any](x []E, less func(a, b E) bool) { - stableLessFunc(x, len(x), less) +// elements, using cmp to compare elements in the same way as [SortFunc]. +func SortStableFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { + stableCmpFunc(x, len(x), cmp) } // IsSorted reports whether x is sorted in ascending order. -func IsSorted[E constraints.Ordered](x []E) bool { +func IsSorted[S ~[]E, E constraints.Ordered](x S) bool { for i := len(x) - 1; i > 0; i-- { - if x[i] < x[i-1] { + if cmpLess(x[i], x[i-1]) { return false } } return true } -// IsSortedFunc reports whether x is sorted in ascending order, with less as the -// comparison function. -func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool { +// IsSortedFunc reports whether x is sorted in ascending order, with cmp as the +// comparison function as defined by [SortFunc]. +func IsSortedFunc[S ~[]E, E any](x S, cmp func(a, b E) int) bool { for i := len(x) - 1; i > 0; i-- { - if less(x[i], x[i-1]) { + if cmp(x[i], x[i-1]) < 0 { return false } } return true } +// Min returns the minimal value in x. It panics if x is empty. +// For floating-point numbers, Min propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Min[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Min: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = min(m, x[i]) + } + return m +} + +// MinFunc returns the minimal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one minimal element +// according to the cmp function, MinFunc returns the first one. +func MinFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MinFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) < 0 { + m = x[i] + } + } + return m +} + +// Max returns the maximal value in x. It panics if x is empty. +// For floating-point E, Max propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Max[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Max: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = max(m, x[i]) + } + return m +} + +// MaxFunc returns the maximal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one maximal element +// according to the cmp function, MaxFunc returns the first one. +func MaxFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MaxFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) > 0 { + m = x[i] + } + } + return m +} + // BinarySearch searches for target in a sorted slice and returns the position // where target is found, or the position where target would appear in the // sort order; it also returns a bool saying whether the target is really found // in the slice. The slice must be sorted in increasing order. -func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { +func BinarySearch[S ~[]E, E constraints.Ordered](x S, target E) (int, bool) { // Inlining is faster than calling BinarySearchFunc with a lambda. n := len(x) // Define x[-1] < target and x[n] >= target. @@ -70,24 +131,24 @@ func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { for i < j { h := int(uint(i+j) >> 1) // avoid overflow when computing h // i ≤ h < j - if x[h] < target { + if cmpLess(x[h], target) { i = h + 1 // preserves x[i-1] < target } else { j = h // preserves x[j] >= target } } // i == j, x[i-1] < target, and x[j] (= x[i]) >= target => answer is i. - return i, i < n && x[i] == target + return i, i < n && (x[i] == target || (isNaN(x[i]) && isNaN(target))) } -// BinarySearchFunc works like BinarySearch, but uses a custom comparison +// BinarySearchFunc works like [BinarySearch], but uses a custom comparison // function. The slice must be sorted in increasing order, where "increasing" // is defined by cmp. cmp should return 0 if the slice element matches // the target, a negative number if the slice element precedes the target, // or a positive number if the slice element follows the target. // cmp must implement the same ordering as the slice, such that if // cmp(a, t) < 0 and cmp(b, t) >= 0, then a must precede b in the slice. -func BinarySearchFunc[E, T any](x []E, target T, cmp func(E, T) int) (int, bool) { +func BinarySearchFunc[S ~[]E, E, T any](x S, target T, cmp func(E, T) int) (int, bool) { n := len(x) // Define cmp(x[-1], target) < 0 and cmp(x[n], target) >= 0 . // Invariant: cmp(x[i - 1], target) < 0, cmp(x[j], target) >= 0. @@ -126,3 +187,9 @@ func (r *xorshift) Next() uint64 { func nextPowerOfTwo(length int) uint { return 1 << bits.Len(uint(length)) } + +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. +func isNaN[T constraints.Ordered](x T) bool { + return x != x +} diff --git a/slices/sort_benchmark_test.go b/slices/sort_benchmark_test.go index ee49f66bc..6aa03c81d 100644 --- a/slices/sort_benchmark_test.go +++ b/slices/sort_benchmark_test.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "sort" + "strconv" "strings" "testing" ) @@ -50,6 +51,15 @@ func BenchmarkSortInts(b *testing.B) { } } +func makeSortedStrings(n int) []string { + x := make([]string, n) + for i := 0; i < n; i++ { + x[i] = strconv.Itoa(i) + } + Sort(x) + return x +} + func BenchmarkSlicesSortInts(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() @@ -77,6 +87,24 @@ func BenchmarkSlicesSortInts_Reversed(b *testing.B) { } } +func BenchmarkIntsAreSorted(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + ints := makeSortedInts(N) + b.StartTimer() + sort.IntsAreSorted(ints) + } +} + +func BenchmarkIsSorted(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + ints := makeSortedInts(N) + b.StartTimer() + IsSorted(ints) + } +} + // Since we're benchmarking these sorts against each other, make sure that they // generate similar results. func TestIntSorts(t *testing.T) { @@ -96,7 +124,7 @@ func TestIntSorts(t *testing.T) { // The following is a benchmark for sorting strings. // makeRandomStrings generates n random strings with alphabetic runes of -// varying lenghts. +// varying lengths. func makeRandomStrings(n int) []string { rand.Seed(42) var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -135,6 +163,15 @@ func BenchmarkSortStrings(b *testing.B) { } } +func BenchmarkSortStrings_Sorted(b *testing.B) { + ss := makeSortedStrings(N) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + sort.Strings(ss) + } +} + func BenchmarkSlicesSortStrings(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() @@ -144,6 +181,15 @@ func BenchmarkSlicesSortStrings(b *testing.B) { } } +func BenchmarkSlicesSortStrings_Sorted(b *testing.B) { + ss := makeSortedStrings(N) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + Sort(ss) + } +} + // These benchmarks compare sorting a slice of structs with sort.Sort vs. // slices.SortFunc. type myStruct struct { @@ -174,7 +220,7 @@ func TestStructSorts(t *testing.T) { } sort.Sort(ss) - SortFunc(ss2, func(a, b *myStruct) bool { return a.n < b.n }) + SortFunc(ss2, func(a, b *myStruct) int { return a.n - b.n }) for i := range ss { if *ss[i] != *ss2[i] { @@ -193,12 +239,12 @@ func BenchmarkSortStructs(b *testing.B) { } func BenchmarkSortFuncStructs(b *testing.B) { - lessFunc := func(a, b *myStruct) bool { return a.n < b.n } + cmpFunc := func(a, b *myStruct) int { return a.n - b.n } for i := 0; i < b.N; i++ { b.StopTimer() ss := makeRandomStructs(N) b.StartTimer() - SortFunc(ss, lessFunc) + SortFunc(ss, cmpFunc) } } diff --git a/slices/sort_test.go b/slices/sort_test.go index 3befa3e7e..501b5ee6d 100644 --- a/slices/sort_test.go +++ b/slices/sort_test.go @@ -5,6 +5,7 @@ package slices import ( + "fmt" "math" "math/rand" "sort" @@ -29,7 +30,7 @@ func TestSortIntSlice(t *testing.T) { func TestSortFuncIntSlice(t *testing.T) { data := Clone(ints[:]) - SortFunc(data, func(a, b int) bool { return a < b }) + SortFunc(data, func(a, b int) int { return a - b }) if !IsSorted(data) { t.Errorf("sorted %v", ints) t.Errorf(" got %v", data) @@ -47,17 +48,18 @@ func TestSortFloat64Slice(t *testing.T) { func TestSortFloat64SliceWithNaNs(t *testing.T) { data := float64sWithNaNs[:] - input := Clone(data) + data2 := Clone(data) - // Make sure Sort doesn't panic when the slice contains NaNs. Sort(data) - // Check whether the result is a permutation of the input. - sort.Float64s(data) - sort.Float64s(input) - for i, v := range input { - if data[i] != v && !(math.IsNaN(data[i]) && math.IsNaN(v)) { - t.Fatalf("the result is not a permutation of the input\ngot %v\nwant %v", data, input) - } + sort.Float64s(data2) + + if !IsSorted(data) { + t.Error("IsSorted indicates data isn't sorted") + } + + // Compare for equality using cmp.Compare, which considers NaNs equal. + if !EqualFunc(data, data2, func(a, b float64) bool { return cmpCompare(a, b) == 0 }) { + t.Errorf("mismatch between Sort and sort.Float64: got %v, want %v", data, data2) } } @@ -95,8 +97,8 @@ type intPair struct { type intPairs []intPair // Pairs compare on a only. -func intPairLess(x, y intPair) bool { - return x.a < y.a +func intPairCmp(x, y intPair) int { + return x.a - y.a } // Record initial order in B. @@ -134,12 +136,12 @@ func TestStability(t *testing.T) { for i := 0; i < len(data); i++ { data[i].a = rand.Intn(m) } - if IsSortedFunc(data, intPairLess) { + if IsSortedFunc(data, intPairCmp) { t.Fatalf("terrible rand.rand") } data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable didn't sort %d ints", n) } if !data.inOrder() { @@ -148,8 +150,8 @@ func TestStability(t *testing.T) { // already sorted data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable shuffled sorted %d ints (order)", n) } if !data.inOrder() { @@ -161,8 +163,8 @@ func TestStability(t *testing.T) { data[i].a = len(data) - i } data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable didn't sort %d ints", n) } if !data.inOrder() { @@ -170,6 +172,125 @@ func TestStability(t *testing.T) { } } +type S struct { + a int + b string +} + +func cmpS(s1, s2 S) int { + return cmpCompare(s1.a, s2.a) +} + +func TestMinMax(t *testing.T) { + intCmp := func(a, b int) int { return a - b } + + tests := []struct { + data []int + wantMin int + wantMax int + }{ + {[]int{7}, 7, 7}, + {[]int{1, 2}, 1, 2}, + {[]int{2, 1}, 1, 2}, + {[]int{1, 2, 3}, 1, 3}, + {[]int{3, 2, 1}, 1, 3}, + {[]int{2, 1, 3}, 1, 3}, + {[]int{2, 2, 3}, 2, 3}, + {[]int{3, 2, 3}, 2, 3}, + {[]int{0, 2, -9}, -9, 2}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) { + gotMin := Min(tt.data) + if gotMin != tt.wantMin { + t.Errorf("Min got %v, want %v", gotMin, tt.wantMin) + } + + gotMinFunc := MinFunc(tt.data, intCmp) + if gotMinFunc != tt.wantMin { + t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin) + } + + gotMax := Max(tt.data) + if gotMax != tt.wantMax { + t.Errorf("Max got %v, want %v", gotMax, tt.wantMax) + } + + gotMaxFunc := MaxFunc(tt.data, intCmp) + if gotMaxFunc != tt.wantMax { + t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax) + } + }) + } + + svals := []S{ + {1, "a"}, + {2, "a"}, + {1, "b"}, + {2, "b"}, + } + + gotMin := MinFunc(svals, cmpS) + wantMin := S{1, "a"} + if gotMin != wantMin { + t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin) + } + + gotMax := MaxFunc(svals, cmpS) + wantMax := S{2, "a"} + if gotMax != wantMax { + t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax) + } +} + +func TestMinMaxNaNs(t *testing.T) { + fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14} + if Min(fs) != -400.4 { + t.Errorf("got min %v, want -400.4", Min(fs)) + } + if Max(fs) != 999.9 { + t.Errorf("got max %v, want 999.9", Max(fs)) + } + + // No matter which element of fs is replaced with a NaN, both Min and Max + // should propagate the NaN to their output. + for i := 0; i < len(fs); i++ { + testfs := Clone(fs) + testfs[i] = math.NaN() + + fmin := Min(testfs) + if !math.IsNaN(fmin) { + t.Errorf("got min %v, want NaN", fmin) + } + + fmax := Max(testfs) + if !math.IsNaN(fmax) { + t.Errorf("got max %v, want NaN", fmax) + } + } +} + +func TestMinMaxPanics(t *testing.T) { + intCmp := func(a, b int) int { return a - b } + emptySlice := []int{} + + if !panics(func() { Min(emptySlice) }) { + t.Errorf("Min([]): got no panic, want panic") + } + + if !panics(func() { Max(emptySlice) }) { + t.Errorf("Max([]): got no panic, want panic") + } + + if !panics(func() { MinFunc(emptySlice, intCmp) }) { + t.Errorf("MinFunc([]): got no panic, want panic") + } + + if !panics(func() { MaxFunc(emptySlice, intCmp) }) { + t.Errorf("MaxFunc([]): got no panic, want panic") + } +} + func TestBinarySearch(t *testing.T) { str1 := []string{"foo"} str2 := []string{"ab", "ca"} @@ -282,6 +403,32 @@ func TestBinarySearchInts(t *testing.T) { } } +func TestBinarySearchFloats(t *testing.T) { + data := []float64{math.NaN(), -0.25, 0.0, 1.4} + tests := []struct { + target float64 + wantPos int + wantFound bool + }{ + {math.NaN(), 0, true}, + {math.Inf(-1), 1, false}, + {-0.25, 1, true}, + {0.0, 2, true}, + {1.4, 3, true}, + {1.5, 4, false}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) { + { + pos, found := BinarySearch(data, tt.target) + if pos != tt.wantPos || found != tt.wantFound { + t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) + } + } + }) + } +} + func TestBinarySearchFunc(t *testing.T) { data := []int{1, 10, 11, 2} // sorted lexicographically cmp := func(a int, b string) int { diff --git a/slices/zsortfunc.go b/slices/zsortanyfunc.go similarity index 64% rename from slices/zsortfunc.go rename to slices/zsortanyfunc.go index 2a632476c..06f2c7a24 100644 --- a/slices/zsortfunc.go +++ b/slices/zsortanyfunc.go @@ -6,28 +6,28 @@ package slices -// insertionSortLessFunc sorts data[a:b] using insertion sort. -func insertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +// insertionSortCmpFunc sorts data[a:b] using insertion sort. +func insertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { for i := a + 1; i < b; i++ { - for j := i; j > a && less(data[j], data[j-1]); j-- { + for j := i; j > a && (cmp(data[j], data[j-1]) < 0); j-- { data[j], data[j-1] = data[j-1], data[j] } } } -// siftDownLessFunc implements the heap property on data[lo:hi]. +// siftDownCmpFunc implements the heap property on data[lo:hi]. // first is an offset into the array where the root of the heap lies. -func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool) { +func siftDownCmpFunc[E any](data []E, lo, hi, first int, cmp func(a, b E) int) { root := lo for { child := 2*root + 1 if child >= hi { break } - if child+1 < hi && less(data[first+child], data[first+child+1]) { + if child+1 < hi && (cmp(data[first+child], data[first+child+1]) < 0) { child++ } - if !less(data[first+root], data[first+child]) { + if !(cmp(data[first+root], data[first+child]) < 0) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -35,30 +35,30 @@ func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool } } -func heapSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func heapSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { first := a lo := 0 hi := b - a // Build heap with greatest element at top. for i := (hi - 1) / 2; i >= 0; i-- { - siftDownLessFunc(data, i, hi, first, less) + siftDownCmpFunc(data, i, hi, first, cmp) } // Pop elements, largest first, into end of data. for i := hi - 1; i >= 0; i-- { data[first], data[first+i] = data[first+i], data[first] - siftDownLessFunc(data, lo, i, first, less) + siftDownCmpFunc(data, lo, i, first, cmp) } } -// pdqsortLessFunc sorts data[a:b]. +// pdqsortCmpFunc sorts data[a:b]. // The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. // pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf // C++ implementation: https://github.com/orlp/pdqsort // Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ // limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. -func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { +func pdqsortCmpFunc[E any](data []E, a, b, limit int, cmp func(a, b E) int) { const maxInsertion = 12 var ( @@ -70,25 +70,25 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { length := b - a if length <= maxInsertion { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) return } // Fall back to heapsort if too many bad choices were made. if limit == 0 { - heapSortLessFunc(data, a, b, less) + heapSortCmpFunc(data, a, b, cmp) return } // If the last partitioning was imbalanced, we need to breaking patterns. if !wasBalanced { - breakPatternsLessFunc(data, a, b, less) + breakPatternsCmpFunc(data, a, b, cmp) limit-- } - pivot, hint := choosePivotLessFunc(data, a, b, less) + pivot, hint := choosePivotCmpFunc(data, a, b, cmp) if hint == decreasingHint { - reverseRangeLessFunc(data, a, b, less) + reverseRangeCmpFunc(data, a, b, cmp) // The chosen pivot was pivot-a elements after the start of the array. // After reversing it is pivot-a elements before the end of the array. // The idea came from Rust's implementation. @@ -98,48 +98,48 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { // The slice is likely already sorted. if wasBalanced && wasPartitioned && hint == increasingHint { - if partialInsertionSortLessFunc(data, a, b, less) { + if partialInsertionSortCmpFunc(data, a, b, cmp) { return } } // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. - if a > 0 && !less(data[a-1], data[pivot]) { - mid := partitionEqualLessFunc(data, a, b, pivot, less) + if a > 0 && !(cmp(data[a-1], data[pivot]) < 0) { + mid := partitionEqualCmpFunc(data, a, b, pivot, cmp) a = mid continue } - mid, alreadyPartitioned := partitionLessFunc(data, a, b, pivot, less) + mid, alreadyPartitioned := partitionCmpFunc(data, a, b, pivot, cmp) wasPartitioned = alreadyPartitioned leftLen, rightLen := mid-a, b-mid balanceThreshold := length / 8 if leftLen < rightLen { wasBalanced = leftLen >= balanceThreshold - pdqsortLessFunc(data, a, mid, limit, less) + pdqsortCmpFunc(data, a, mid, limit, cmp) a = mid + 1 } else { wasBalanced = rightLen >= balanceThreshold - pdqsortLessFunc(data, mid+1, b, limit, less) + pdqsortCmpFunc(data, mid+1, b, limit, cmp) b = mid } } } -// partitionLessFunc does one quicksort partition. +// partitionCmpFunc does one quicksort partition. // Let p = data[pivot] // Moves elements in data[a:b] around, so that data[i]

=p for inewpivot. // On return, data[newpivot] = p -func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int, alreadyPartitioned bool) { +func partitionCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int, alreadyPartitioned bool) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -151,10 +151,10 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) j-- for { - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -168,17 +168,17 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) return j, false } -// partitionEqualLessFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. +// partitionEqualCmpFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. // It assumed that data[a:b] does not contain elements smaller than the data[pivot]. -func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int) { +func partitionEqualCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !less(data[a], data[i]) { + for i <= j && !(cmp(data[a], data[i]) < 0) { i++ } - for i <= j && less(data[a], data[j]) { + for i <= j && (cmp(data[a], data[j]) < 0) { j-- } if i > j { @@ -191,15 +191,15 @@ func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) return i } -// partialInsertionSortLessFunc partially sorts a slice, returns true if the slice is sorted at the end. -func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) bool { +// partialInsertionSortCmpFunc partially sorts a slice, returns true if the slice is sorted at the end. +func partialInsertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) bool { const ( maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted shortestShifting = 50 // don't shift any elements on short arrays ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !less(data[i], data[i-1]) { + for i < b && !(cmp(data[i], data[i-1]) < 0) { i++ } @@ -216,7 +216,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -225,7 +225,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the greater one to the right. if b-i >= 2 { for j := i + 1; j < b; j++ { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -235,9 +235,9 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b return false } -// breakPatternsLessFunc scatters some elements around in an attempt to break some patterns +// breakPatternsCmpFunc scatters some elements around in an attempt to break some patterns // that might cause imbalanced partitions in quicksort. -func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func breakPatternsCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { length := b - a if length >= 8 { random := xorshift(length) @@ -253,12 +253,12 @@ func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -// choosePivotLessFunc chooses a pivot in data[a:b]. +// choosePivotCmpFunc chooses a pivot in data[a:b]. // // [0,8): chooses a static pivot. // [8,shortestNinther): uses the simple median-of-three method. // [shortestNinther,∞): uses the Tukey ninther method. -func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (pivot int, hint sortedHint) { +func choosePivotCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) (pivot int, hint sortedHint) { const ( shortestNinther = 50 maxSwaps = 4 * 3 @@ -276,12 +276,12 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv if l >= 8 { if l >= shortestNinther { // Tukey ninther method, the idea came from Rust's implementation. - i = medianAdjacentLessFunc(data, i, &swaps, less) - j = medianAdjacentLessFunc(data, j, &swaps, less) - k = medianAdjacentLessFunc(data, k, &swaps, less) + i = medianAdjacentCmpFunc(data, i, &swaps, cmp) + j = medianAdjacentCmpFunc(data, j, &swaps, cmp) + k = medianAdjacentCmpFunc(data, k, &swaps, cmp) } // Find the median among i, j, k and stores it into j. - j = medianLessFunc(data, i, j, k, &swaps, less) + j = medianCmpFunc(data, i, j, k, &swaps, cmp) } switch swaps { @@ -294,29 +294,29 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv } } -// order2LessFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. -func order2LessFunc[E any](data []E, a, b int, swaps *int, less func(a, b E) bool) (int, int) { - if less(data[b], data[a]) { +// order2CmpFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. +func order2CmpFunc[E any](data []E, a, b int, swaps *int, cmp func(a, b E) int) (int, int) { + if cmp(data[b], data[a]) < 0 { *swaps++ return b, a } return a, b } -// medianLessFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. -func medianLessFunc[E any](data []E, a, b, c int, swaps *int, less func(a, b E) bool) int { - a, b = order2LessFunc(data, a, b, swaps, less) - b, c = order2LessFunc(data, b, c, swaps, less) - a, b = order2LessFunc(data, a, b, swaps, less) +// medianCmpFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. +func medianCmpFunc[E any](data []E, a, b, c int, swaps *int, cmp func(a, b E) int) int { + a, b = order2CmpFunc(data, a, b, swaps, cmp) + b, c = order2CmpFunc(data, b, c, swaps, cmp) + a, b = order2CmpFunc(data, a, b, swaps, cmp) return b } -// medianAdjacentLessFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. -func medianAdjacentLessFunc[E any](data []E, a int, swaps *int, less func(a, b E) bool) int { - return medianLessFunc(data, a-1, a, a+1, swaps, less) +// medianAdjacentCmpFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. +func medianAdjacentCmpFunc[E any](data []E, a int, swaps *int, cmp func(a, b E) int) int { + return medianCmpFunc(data, a-1, a, a+1, swaps, cmp) } -func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func reverseRangeCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { i := a j := b - 1 for i < j { @@ -326,37 +326,37 @@ func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -func swapRangeLessFunc[E any](data []E, a, b, n int, less func(a, b E) bool) { +func swapRangeCmpFunc[E any](data []E, a, b, n int, cmp func(a, b E) int) { for i := 0; i < n; i++ { data[a+i], data[b+i] = data[b+i], data[a+i] } } -func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { +func stableCmpFunc[E any](data []E, n int, cmp func(a, b E) int) { blockSize := 20 // must be > 0 a, b := 0, blockSize for b <= n { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) a = b b += blockSize } - insertionSortLessFunc(data, a, n, less) + insertionSortCmpFunc(data, a, n, cmp) for blockSize < n { a, b = 0, 2*blockSize for b <= n { - symMergeLessFunc(data, a, a+blockSize, b, less) + symMergeCmpFunc(data, a, a+blockSize, b, cmp) a = b b += 2 * blockSize } if m := a + blockSize; m < n { - symMergeLessFunc(data, a, m, n, less) + symMergeCmpFunc(data, a, m, n, cmp) } blockSize *= 2 } } -// symMergeLessFunc merges the two sorted subsequences data[a:m] and data[m:b] using +// symMergeCmpFunc merges the two sorted subsequences data[a:m] and data[m:b] using // the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum // Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz // Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in @@ -375,7 +375,7 @@ func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { // symMerge assumes non-degenerate arguments: a < m && m < b. // Having the caller check this condition eliminates many leaf recursion calls, // which improves performance. -func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func symMergeCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { // Avoid unnecessary recursions of symMerge // by direct insertion of data[a] into data[m:b] // if data[a:m] only contains one element. @@ -387,7 +387,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := b for i < j { h := int(uint(i+j) >> 1) - if less(data[h], data[a]) { + if cmp(data[h], data[a]) < 0 { i = h + 1 } else { j = h @@ -411,7 +411,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := m for i < j { h := int(uint(i+j) >> 1) - if !less(data[m], data[h]) { + if !(cmp(data[m], data[h]) < 0) { i = h + 1 } else { j = h @@ -438,7 +438,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { for start < r { c := int(uint(start+r) >> 1) - if !less(data[p-c], data[c]) { + if !(cmp(data[p-c], data[c]) < 0) { start = c + 1 } else { r = c @@ -447,33 +447,33 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { end := n - start if start < m && m < end { - rotateLessFunc(data, start, m, end, less) + rotateCmpFunc(data, start, m, end, cmp) } if a < start && start < mid { - symMergeLessFunc(data, a, start, mid, less) + symMergeCmpFunc(data, a, start, mid, cmp) } if mid < end && end < b { - symMergeLessFunc(data, mid, end, b, less) + symMergeCmpFunc(data, mid, end, b, cmp) } } -// rotateLessFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: +// rotateCmpFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: // Data of the form 'x u v y' is changed to 'x v u y'. // rotate performs at most b-a many calls to data.Swap, // and it assumes non-degenerate arguments: a < m && m < b. -func rotateLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func rotateCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { i := m - a j := b - m for i != j { if i > j { - swapRangeLessFunc(data, m-i, m, j, less) + swapRangeCmpFunc(data, m-i, m, j, cmp) i -= j } else { - swapRangeLessFunc(data, m-i, m+j-i, i, less) + swapRangeCmpFunc(data, m-i, m+j-i, i, cmp) j -= i } } // i == j - swapRangeLessFunc(data, m-i, m, i, less) + swapRangeCmpFunc(data, m-i, m, i, cmp) } diff --git a/slices/zsortordered.go b/slices/zsortordered.go index efaa1c8b7..99b47c398 100644 --- a/slices/zsortordered.go +++ b/slices/zsortordered.go @@ -11,7 +11,7 @@ import "golang.org/x/exp/constraints" // insertionSortOrdered sorts data[a:b] using insertion sort. func insertionSortOrdered[E constraints.Ordered](data []E, a, b int) { for i := a + 1; i < b; i++ { - for j := i; j > a && (data[j] < data[j-1]); j-- { + for j := i; j > a && cmpLess(data[j], data[j-1]); j-- { data[j], data[j-1] = data[j-1], data[j] } } @@ -26,10 +26,10 @@ func siftDownOrdered[E constraints.Ordered](data []E, lo, hi, first int) { if child >= hi { break } - if child+1 < hi && (data[first+child] < data[first+child+1]) { + if child+1 < hi && cmpLess(data[first+child], data[first+child+1]) { child++ } - if !(data[first+root] < data[first+child]) { + if !cmpLess(data[first+root], data[first+child]) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -107,7 +107,7 @@ func pdqsortOrdered[E constraints.Ordered](data []E, a, b, limit int) { // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. - if a > 0 && !(data[a-1] < data[pivot]) { + if a > 0 && !cmpLess(data[a-1], data[pivot]) { mid := partitionEqualOrdered(data, a, b, pivot) a = mid continue @@ -138,10 +138,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -153,10 +153,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo j-- for { - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -177,10 +177,10 @@ func partitionEqualOrdered[E constraints.Ordered](data []E, a, b, pivot int) (ne i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !(data[a] < data[i]) { + for i <= j && !cmpLess(data[a], data[i]) { i++ } - for i <= j && (data[a] < data[j]) { + for i <= j && cmpLess(data[a], data[j]) { j-- } if i > j { @@ -201,7 +201,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !(data[i] < data[i-1]) { + for i < b && !cmpLess(data[i], data[i-1]) { i++ } @@ -218,7 +218,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -227,7 +227,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the greater one to the right. if b-i >= 2 { for j := i + 1; j < b; j++ { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -298,7 +298,7 @@ func choosePivotOrdered[E constraints.Ordered](data []E, a, b int) (pivot int, h // order2Ordered returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. func order2Ordered[E constraints.Ordered](data []E, a, b int, swaps *int) (int, int) { - if data[b] < data[a] { + if cmpLess(data[b], data[a]) { *swaps++ return b, a } @@ -389,7 +389,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := b for i < j { h := int(uint(i+j) >> 1) - if data[h] < data[a] { + if cmpLess(data[h], data[a]) { i = h + 1 } else { j = h @@ -413,7 +413,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := m for i < j { h := int(uint(i+j) >> 1) - if !(data[m] < data[h]) { + if !cmpLess(data[m], data[h]) { i = h + 1 } else { j = h @@ -440,7 +440,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { for start < r { c := int(uint(start+r) >> 1) - if !(data[p-c] < data[c]) { + if !cmpLess(data[p-c], data[c]) { start = c + 1 } else { r = c