From 302865e7556b4ae5de27248ce625d443ef4ad3ed Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Thu, 20 Jul 2023 21:36:10 -0700 Subject: [PATCH] slices: update to current standard library version Update x/exp/slices to the current standard library slices package, while retaining the ability to use it with Go 1.18 through Go 1.20. Note that this changes some of the sorting functions to use a comparison function rather than a less function. We don't promise backward compatibility in x/exp packages. Being compatible with the Go 1.21 package seems more useful for people not yet using 1.21, as it will make the transition to 1.21 easier. The generated files were built using "go generate" with a GOROOT that included CL 511660. Fixes golang/go#61374 Change-Id: I4abfd9db92d553f554aec83d60f0c13fa56c1d8e Reviewed-on: https://go-review.googlesource.com/c/exp/+/511895 TryBot-Result: Gopher Robot Auto-Submit: Ian Lance Taylor Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor Run-TryBot: Ian Lance Taylor Reviewed-by: Eli Bendersky --- slices/cmp.go | 44 +++ slices/slices.go | 353 ++++++++++++++++++----- slices/slices_test.go | 232 +++++++++++++-- slices/sort.go | 115 ++++++-- slices/sort_benchmark_test.go | 54 +++- slices/sort_test.go | 185 ++++++++++-- slices/{zsortfunc.go => zsortanyfunc.go} | 154 +++++----- slices/zsortordered.go | 34 +-- 8 files changed, 942 insertions(+), 229 deletions(-) create mode 100644 slices/cmp.go rename slices/{zsortfunc.go => zsortanyfunc.go} (64%) diff --git a/slices/cmp.go b/slices/cmp.go new file mode 100644 index 000000000..fbf1934a0 --- /dev/null +++ b/slices/cmp.go @@ -0,0 +1,44 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package slices + +import "golang.org/x/exp/constraints" + +// min is a version of the predeclared function from the Go 1.21 release. +func min[T constraints.Ordered](a, b T) T { + if a < b || isNaN(a) { + return a + } + return b +} + +// max is a version of the predeclared function from the Go 1.21 release. +func max[T constraints.Ordered](a, b T) T { + if a > b || isNaN(a) { + return a + } + return b +} + +// cmpLess is a copy of cmp.Less from the Go 1.21 release. +func cmpLess[T constraints.Ordered](x, y T) bool { + return (isNaN(x) && !isNaN(y)) || x < y +} + +// cmpCompare is a copy of cmp.Compare from the Go 1.21 release. +func cmpCompare[T constraints.Ordered](x, y T) int { + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} diff --git a/slices/slices.go b/slices/slices.go index 8a7cf20db..5e8158bba 100644 --- a/slices/slices.go +++ b/slices/slices.go @@ -3,23 +3,20 @@ // license that can be found in the LICENSE file. // Package slices defines various functions useful with slices of any type. -// Unless otherwise specified, these functions all apply to the elements -// of a slice at index 0 <= i < len(s). -// -// Note that the less function in IsSortedFunc, SortFunc, SortStableFunc requires a -// strict weak ordering (https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings), -// or the sorting may fail to sort correctly. A common case is when sorting slices of -// floating-point numbers containing NaN values. package slices -import "golang.org/x/exp/constraints" +import ( + "unsafe" + + "golang.org/x/exp/constraints" +) // Equal reports whether two slices are equal: the same length and all // elements equal. If the lengths are different, Equal returns false. // Otherwise, the elements are compared in increasing index order, and the // comparison stops at the first unequal pair. // Floating point NaNs are not considered equal. -func Equal[E comparable](s1, s2 []E) bool { +func Equal[S ~[]E, E comparable](s1, s2 S) bool { if len(s1) != len(s2) { return false } @@ -31,12 +28,12 @@ func Equal[E comparable](s1, s2 []E) bool { return true } -// EqualFunc reports whether two slices are equal using a comparison +// EqualFunc reports whether two slices are equal using an equality // function on each pair of elements. If the lengths are different, // EqualFunc returns false. Otherwise, the elements are compared in // increasing index order, and the comparison stops at the first index // for which eq returns false. -func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { +func EqualFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, eq func(E1, E2) bool) bool { if len(s1) != len(s2) { return false } @@ -49,45 +46,37 @@ func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool { return true } -// Compare compares the elements of s1 and s2. -// The elements are compared sequentially, starting at index 0, +// Compare compares the elements of s1 and s2, using [cmp.Compare] on each pair +// of elements. The elements are compared sequentially, starting at index 0, // until one element is not equal to the other. // The result of comparing the first non-matching elements is returned. // If both slices are equal until one of them ends, the shorter slice is // considered less than the longer one. // The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2. -// Comparisons involving floating point NaNs are ignored. -func Compare[E constraints.Ordered](s1, s2 []E) int { - s2len := len(s2) +func Compare[S ~[]E, E constraints.Ordered](s1, s2 S) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] - switch { - case v1 < v2: - return -1 - case v1 > v2: - return +1 + if c := cmpCompare(v1, v2); c != 0 { + return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 } -// CompareFunc is like Compare but uses a comparison function -// on each pair of elements. The elements are compared in increasing -// index order, and the comparisons stop after the first time cmp -// returns non-zero. +// CompareFunc is like [Compare] but uses a custom comparison function on each +// pair of elements. // The result is the first non-zero result of cmp; if cmp always // returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2), // and +1 if len(s1) > len(s2). -func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { - s2len := len(s2) +func CompareFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, cmp func(E1, E2) int) int { for i, v1 := range s1 { - if i >= s2len { + if i >= len(s2) { return +1 } v2 := s2[i] @@ -95,7 +84,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { return c } } - if len(s1) < s2len { + if len(s1) < len(s2) { return -1 } return 0 @@ -103,7 +92,7 @@ func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int { // Index returns the index of the first occurrence of v in s, // or -1 if not present. -func Index[E comparable](s []E, v E) int { +func Index[S ~[]E, E comparable](s S, v E) int { for i := range s { if v == s[i] { return i @@ -114,7 +103,7 @@ func Index[E comparable](s []E, v E) int { // IndexFunc returns the first index i satisfying f(s[i]), // or -1 if none do. -func IndexFunc[E any](s []E, f func(E) bool) int { +func IndexFunc[S ~[]E, E any](s S, f func(E) bool) int { for i := range s { if f(s[i]) { return i @@ -124,39 +113,104 @@ func IndexFunc[E any](s []E, f func(E) bool) int { } // Contains reports whether v is present in s. -func Contains[E comparable](s []E, v E) bool { +func Contains[S ~[]E, E comparable](s S, v E) bool { return Index(s, v) >= 0 } // ContainsFunc reports whether at least one // element e of s satisfies f(e). -func ContainsFunc[E any](s []E, f func(E) bool) bool { +func ContainsFunc[S ~[]E, E any](s S, f func(E) bool) bool { return IndexFunc(s, f) >= 0 } // Insert inserts the values v... into s at index i, // returning the modified slice. -// In the returned slice r, r[i] == v[0]. +// The elements at s[i:] are shifted up to make room. +// In the returned slice r, r[i] == v[0], +// and r[i+len(v)] == value originally at r[i]. // Insert panics if i is out of range. // This function is O(len(s) + len(v)). func Insert[S ~[]E, E any](s S, i int, v ...E) S { - tot := len(s) + len(v) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[i:]) + m := len(v) + if m == 0 { + return s + } + n := len(s) + if i == n { + return append(s, v...) + } + if n+m > cap(s) { + // Use append rather than make so that we bump the size of + // the slice up to the next storage class. + // This is what Grow does but we don't call Grow because + // that might copy the values twice. + s2 := append(s[:i], make(S, n+m-i)...) copy(s2[i:], v) + copy(s2[i+m:], s[i:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[i:]) - return s2 + s = s[:n+m] + + // before: + // s: aaaaaaaabbbbccccccccdddd + // ^ ^ ^ ^ + // i i+m n n+m + // after: + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // + // a are the values that don't move in s. + // v are the values copied in from v. + // b and c are the values from s that are shifted up in index. + // d are the values that get overwritten, never to be seen again. + + if !overlaps(v, s[i+m:]) { + // Easy case - v does not overlap either the c or d regions. + // (It might be in some of a or b, or elsewhere entirely.) + // The data we copy up doesn't write to v at all, so just do it. + + copy(s[i+m:], s[i:]) + + // Now we have + // s: aaaaaaaabbbbbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // Note the b values are duplicated. + + copy(s[i:], v) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s + } + + // The hard case - v overlaps c or d. We can't just shift up + // the data because we'd move or clobber the values we're trying + // to insert. + // So instead, write v on top of d, then rotate. + copy(s[n:], v) + + // Now we have + // s: aaaaaaaabbbbccccccccvvvv + // ^ ^ ^ ^ + // i i+m n n+m + + rotateRight(s[i:], m) + + // Now we have + // s: aaaaaaaavvvvbbbbcccccccc + // ^ ^ ^ ^ + // i i+m n n+m + // That's the result we want. + return s } // Delete removes the elements s[i:j] from s, returning the modified slice. // Delete panics if s[i:j] is not a valid slice of s. -// Delete modifies the contents of the slice s; it does not create a new slice. // Delete is O(len(s)-j), so if many items must be deleted, it is better to // make a single call deleting them all together than to delete one at a time. // Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those @@ -175,39 +229,106 @@ func Delete[S ~[]E, E any](s S, i, j int) S { // zeroing those elements so that objects they reference can be garbage // collected. func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := IndexFunc(s, del) + if i == -1 { + return s + } // Don't start copying elements until we find one to delete. - for i, v := range s { - if del(v) { - j := i - for i++; i < len(s); i++ { - v = s[i] - if !del(v) { - s[j] = v - j++ - } - } - return s[:j] + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ } } - return s + return s[:i] } // Replace replaces the elements s[i:j] by the given v, and returns the // modified slice. Replace panics if s[i:j] is not a valid slice of s. func Replace[S ~[]E, E any](s S, i, j int, v ...E) S { _ = s[i:j] // verify that i:j is a valid subslice + + if i == j { + return Insert(s, i, v...) + } + if j == len(s) { + return append(s[:i], v...) + } + tot := len(s[:i]) + len(v) + len(s[j:]) - if tot <= cap(s) { - s2 := s[:tot] - copy(s2[i+len(v):], s[j:]) + if tot > cap(s) { + // Too big to fit, allocate and copy over. + s2 := append(s[:i], make(S, tot-i)...) // See Insert copy(s2[i:], v) + copy(s2[i+len(v):], s[j:]) return s2 } - s2 := make(S, tot) - copy(s2, s[:i]) - copy(s2[i:], v) - copy(s2[i+len(v):], s[j:]) - return s2 + + r := s[:tot] + + if i+len(v) <= j { + // Easy, as v fits in the deleted portion. + copy(r[i:], v) + if i+len(v) != j { + copy(r[i+len(v):], s[j:]) + } + return r + } + + // We are expanding (v is bigger than j-i). + // The situation is something like this: + // (example has i=4,j=8,len(s)=16,len(v)=6) + // s: aaaaxxxxbbbbbbbbyy + // ^ ^ ^ ^ + // i j len(s) tot + // a: prefix of s + // x: deleted range + // b: more of s + // y: area to expand into + + if !overlaps(r[i+len(v):], v) { + // Easy, as v is not clobbered by the first copy. + copy(r[i+len(v):], s[j:]) + copy(r[i:], v) + return r + } + + // This is a situation where we don't have a single place to which + // we can copy v. Parts of it need to go to two different places. + // We want to copy the prefix of v into y and the suffix into x, then + // rotate |y| spots to the right. + // + // v[2:] v[:2] + // | | + // s: aaaavvvvbbbbbbbbvv + // ^ ^ ^ ^ + // i j len(s) tot + // + // If either of those two destinations don't alias v, then we're good. + y := len(v) - (j - i) // length of y portion + + if !overlaps(r[i:j], v) { + copy(r[i:j], v[y:]) + copy(r[len(s):], v[:y]) + rotateRight(r[i:], y) + return r + } + if !overlaps(r[len(s):], v) { + copy(r[len(s):], v[:y]) + copy(r[i:j], v[y:]) + rotateRight(r[i:], y) + return r + } + + // Now we know that v overlaps both x and y. + // That means that the entirety of b is *inside* v. + // So we don't need to preserve b at all; instead we + // can copy v first, then copy the b part of v out of + // v to the right destination. + k := startIdx(v, s[j:]) + copy(r[i:], v) + copy(r[i+len(v):], r[i+k:]) + return r } // Clone returns a copy of the slice. @@ -222,7 +343,8 @@ func Clone[S ~[]E, E any](s S) S { // Compact replaces consecutive runs of equal elements with a single copy. // This is like the uniq command found on Unix. -// Compact modifies the contents of the slice s; it does not create a new slice. +// Compact modifies the contents of the slice s and returns the modified slice, +// which may have a smaller length. // When Compact discards m elements in total, it might not modify the elements // s[len(s)-m:len(s)]. If those elements contain pointers you might consider // zeroing those elements so that objects they reference can be garbage collected. @@ -242,7 +364,8 @@ func Compact[S ~[]E, E comparable](s S) S { return s[:i] } -// CompactFunc is like Compact but uses a comparison function. +// CompactFunc is like [Compact] but uses an equality function to compare elements. +// For runs of elements that compare equal, CompactFunc keeps the first one. func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { if len(s) < 2 { return s @@ -280,3 +403,97 @@ func Grow[S ~[]E, E any](s S, n int) S { func Clip[S ~[]E, E any](s S) S { return s[:len(s):len(s)] } + +// Rotation algorithm explanation: +// +// rotate left by 2 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join first parts +// 89234567 01 +// recursively rotate first left part by 2 +// 23456789 01 +// join at the end +// 2345678901 +// +// rotate left by 8 +// start with +// 0123456789 +// split up like this +// 01 234567 89 +// swap first 2 and last 2 +// 89 234567 01 +// join last parts +// 89 23456701 +// recursively rotate second part left by 6 +// 89 01234567 +// join at the end +// 8901234567 + +// TODO: There are other rotate algorithms. +// This algorithm has the desirable property that it moves each element exactly twice. +// The triple-reverse algorithm is simpler and more cache friendly, but takes more writes. +// The follow-cycles algorithm can be 1-write but it is not very cache friendly. + +// rotateLeft rotates b left by n spaces. +// s_final[i] = s_orig[i+r], wrapping around. +func rotateLeft[E any](s []E, r int) { + for r != 0 && r != len(s) { + if r*2 <= len(s) { + swap(s[:r], s[len(s)-r:]) + s = s[:len(s)-r] + } else { + swap(s[:len(s)-r], s[r:]) + s, r = s[len(s)-r:], r*2-len(s) + } + } +} +func rotateRight[E any](s []E, r int) { + rotateLeft(s, len(s)-r) +} + +// swap swaps the contents of x and y. x and y must be equal length and disjoint. +func swap[E any](x, y []E) { + for i := 0; i < len(x); i++ { + x[i], y[i] = y[i], x[i] + } +} + +// overlaps reports whether the memory ranges a[0:len(a)] and b[0:len(b)] overlap. +func overlaps[E any](a, b []E) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + elemSize := unsafe.Sizeof(a[0]) + if elemSize == 0 { + return false + } + // TODO: use a runtime/unsafe facility once one becomes available. See issue 12445. + // Also see crypto/internal/alias/alias.go:AnyOverlap + return uintptr(unsafe.Pointer(&a[0])) <= uintptr(unsafe.Pointer(&b[len(b)-1]))+(elemSize-1) && + uintptr(unsafe.Pointer(&b[0])) <= uintptr(unsafe.Pointer(&a[len(a)-1]))+(elemSize-1) +} + +// startIdx returns the index in haystack where the needle starts. +// prerequisite: the needle must be aliased entirely inside the haystack. +func startIdx[E any](haystack, needle []E) int { + p := &needle[0] + for i := range haystack { + if p == &haystack[i] { + return i + } + } + // TODO: what if the overlap is by a non-integral number of Es? + panic("needle not found") +} + +// Reverse reverses the elements of the slice in place. +func Reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/slices/slices_test.go b/slices/slices_test.go index c2402dd76..7371bdb88 100644 --- a/slices/slices_test.go +++ b/slices/slices_test.go @@ -8,8 +8,6 @@ import ( "math" "strings" "testing" - - "golang.org/x/exp/constraints" ) var raceEnabled bool @@ -84,7 +82,7 @@ func equalNaN[T comparable](v1, v2 T) bool { } // offByOne returns true if integers v1 and v2 differ by 1. -func offByOne[E constraints.Integer](v1, v2 E) bool { +func offByOne(v1, v2 int) bool { return v1 == v2+1 || v1 == v2-1 } @@ -105,10 +103,10 @@ func TestEqualFunc(t *testing.T) { s1 := []int{1, 2, 3} s2 := []int{2, 3, 4} - if EqualFunc(s1, s1, offByOne[int]) { + if EqualFunc(s1, s1, offByOne) { t.Errorf("EqualFunc(%v, %v, offByOne) = true, want false", s1, s1) } - if !EqualFunc(s1, s2, offByOne[int]) { + if !EqualFunc(s1, s2, offByOne) { t.Errorf("EqualFunc(%v, %v, offByOne) = false, want true", s1, s2) } @@ -140,6 +138,31 @@ var compareIntTests = []struct { s1, s2 []int want int }{ + { + []int{1}, + []int{1}, + 0, + }, + { + []int{1}, + []int{}, + 1, + }, + { + []int{}, + []int{1}, + -1, + }, + { + []int{}, + []int{}, + 0, + }, + { + []int{1, 2, 3}, + []int{1, 2, 3}, + 0, + }, { []int{1, 2, 3}, []int{1, 2, 3, 4}, @@ -160,12 +183,32 @@ var compareIntTests = []struct { []int{1, 2, 3}, +1, }, + { + []int{1, 4, 3}, + []int{1, 2, 3, 8, 9}, + +1, + }, } var compareFloatTests = []struct { s1, s2 []float64 want int }{ + { + []float64{}, + []float64{}, + 0, + }, + { + []float64{1}, + []float64{1}, + 0, + }, + { + []float64{math.NaN()}, + []float64{math.NaN()}, + 0, + }, { []float64{1, 2, math.NaN()}, []float64{1, 2, math.NaN()}, @@ -184,13 +227,23 @@ var compareFloatTests = []struct { { []float64{1, math.NaN(), 3}, []float64{1, 2, math.NaN()}, - 0, + -1, }, { - []float64{1, math.NaN(), 3, 4}, + []float64{1, 2, 3}, []float64{1, 2, math.NaN()}, +1, }, + { + []float64{1, 2, 3}, + []float64{1, math.NaN(), 3}, + +1, + }, + { + []float64{1, math.NaN(), 3, 4}, + []float64{1, 2, math.NaN()}, + -1, + }, } func TestCompare(t *testing.T) { @@ -232,16 +285,6 @@ func equalToCmp[T comparable](eq func(T, T) bool) func(T, T) int { } } -func cmp[T constraints.Ordered](v1, v2 T) int { - if v1 < v2 { - return -1 - } else if v1 > v2 { - return 1 - } else { - return 0 - } -} - func TestCompareFunc(t *testing.T) { intWant := func(want bool) string { if want { @@ -261,19 +304,19 @@ func TestCompareFunc(t *testing.T) { } for _, test := range compareIntTests { - if got := CompareFunc(test.s1, test.s2, cmp[int]); got != test.want { + if got := CompareFunc(test.s1, test.s2, cmpCompare[int]); got != test.want { t.Errorf("CompareFunc(%v, %v, cmp[int]) = %d, want %d", test.s1, test.s2, got, test.want) } } for _, test := range compareFloatTests { - if got := CompareFunc(test.s1, test.s2, cmp[float64]); got != test.want { + if got := CompareFunc(test.s1, test.s2, cmpCompare[float64]); got != test.want { t.Errorf("CompareFunc(%v, %v, cmp[float64]) = %d, want %d", test.s1, test.s2, got, test.want) } } s1 := []int{1, 2, 3} s2 := []int{2, 3, 4} - if got := CompareFunc(s1, s2, equalToCmp(offByOne[int])); got != 0 { + if got := CompareFunc(s1, s2, equalToCmp(offByOne)); got != 0 { t.Errorf("CompareFunc(%v, %v, offByOne) = %d, want 0", s1, s2, got) } @@ -450,6 +493,45 @@ func TestInsert(t *testing.T) { t.Errorf("Insert(%v, %d, %v...) = %v, want %v", test.s, test.i, test.add, got, test.want) } } + + if !raceEnabled { + // Allocations should be amortized. + const count = 50 + n := testing.AllocsPerRun(10, func() { + s := []int{1, 2, 3} + for i := 0; i < count; i++ { + s = Insert(s, 0, 1) + } + }) + if n > count/2 { + t.Errorf("too many allocations inserting %d elements: got %v, want less than %d", count, n, count/2) + } + } +} + +func TestInsertOverlap(t *testing.T) { + const N = 10 + a := make([]int, N) + want := make([]int, 2*N) + for n := 0; n <= N; n++ { // length + for i := 0; i <= n; i++ { // insertion point + for x := 0; x <= N; x++ { // start of inserted data + for y := x; y <= N; y++ { // end of inserted data + for k := 0; k < N; k++ { + a[k] = k + } + want = want[:0] + want = append(want, a[:i]...) + want = append(want, a[x:y]...) + want = append(want, a[i:n]...) + got := Insert(a[:n], i, a[x:y]...) + if !Equal(got, want) { + t.Errorf("Insert with overlap failed n=%d i=%d x=%d y=%d, got %v want %v", n, i, x, y, got, want) + } + } + } + } + } } var deleteTests = []struct { @@ -748,6 +830,34 @@ func TestClip(t *testing.T) { } } +func TestReverse(t *testing.T) { + even := []int{3, 1, 4, 1, 5, 9} // len = 6 + Reverse(even) + if want := []int{9, 5, 1, 4, 1, 3}; !Equal(even, want) { + t.Errorf("Reverse(even) = %v, want %v", even, want) + } + + odd := []int{3, 1, 4, 1, 5, 9, 2} // len = 7 + Reverse(odd) + if want := []int{2, 9, 5, 1, 4, 1, 3}; !Equal(odd, want) { + t.Errorf("Reverse(odd) = %v, want %v", odd, want) + } + + words := strings.Fields("one two three") + Reverse(words) + if want := strings.Fields("three two one"); !Equal(words, want) { + t.Errorf("Reverse(words) = %v, want %v", words, want) + } + + singleton := []string{"one"} + Reverse(singleton) + if want := []string{"one"}; !Equal(singleton, want) { + t.Errorf("Reverse(singeleton) = %v, want %v", singleton, want) + } + + Reverse[[]string](nil) +} + // naiveReplace is a baseline implementation to the Replace function. func naiveReplace[S ~[]E, E any](s S, i, j int, v ...E) S { s = Delete(s, i, j) @@ -812,6 +922,33 @@ func TestReplacePanics(t *testing.T) { } } +func TestReplaceOverlap(t *testing.T) { + const N = 10 + a := make([]int, N) + want := make([]int, 2*N) + for n := 0; n <= N; n++ { // length + for i := 0; i <= n; i++ { // insertion point 1 + for j := i; j <= n; j++ { // insertion point 2 + for x := 0; x <= N; x++ { // start of inserted data + for y := x; y <= N; y++ { // end of inserted data + for k := 0; k < N; k++ { + a[k] = k + } + want = want[:0] + want = append(want, a[:i]...) + want = append(want, a[x:y]...) + want = append(want, a[j:n]...) + got := Replace(a[:n], i, j, a[x:y]...) + if !Equal(got, want) { + t.Errorf("Insert with overlap failed n=%d i=%d j=%d x=%d y=%d, got %v want %v", n, i, j, x, y, got, want) + } + } + } + } + } + } +} + func BenchmarkReplace(b *testing.B) { cases := []struct { name string @@ -860,3 +997,58 @@ func BenchmarkReplace(b *testing.B) { } } + +func TestRotate(t *testing.T) { + const N = 10 + s := make([]int, 0, N) + for n := 0; n < N; n++ { + for r := 0; r < n; r++ { + s = s[:0] + for i := 0; i < n; i++ { + s = append(s, i) + } + rotateLeft(s, r) + for i := 0; i < n; i++ { + if s[i] != (i+r)%n { + t.Errorf("expected n=%d r=%d i:%d want:%d got:%d", n, r, i, (i+r)%n, s[i]) + } + } + } + } +} + +func TestInsertGrowthRate(t *testing.T) { + b := make([]byte, 1) + maxCap := cap(b) + nGrow := 0 + const N = 1e6 + for i := 0; i < N; i++ { + b = Insert(b, len(b)-1, 0) + if cap(b) > maxCap { + maxCap = cap(b) + nGrow++ + } + } + want := int(math.Log(N) / math.Log(1.25)) // 1.25 == growth rate for large slices + if nGrow > want { + t.Errorf("too many grows. got:%d want:%d", nGrow, want) + } +} + +func TestReplaceGrowthRate(t *testing.T) { + b := make([]byte, 2) + maxCap := cap(b) + nGrow := 0 + const N = 1e6 + for i := 0; i < N; i++ { + b = Replace(b, len(b)-2, len(b)-1, 0, 0) + if cap(b) > maxCap { + maxCap = cap(b) + nGrow++ + } + } + want := int(math.Log(N) / math.Log(1.25)) // 1.25 == growth rate for large slices + if nGrow > want { + t.Errorf("too many grows. got:%d want:%d", nGrow, want) + } +} diff --git a/slices/sort.go b/slices/sort.go index 231b6448a..b67897f76 100644 --- a/slices/sort.go +++ b/slices/sort.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:generate go run $GOROOT/src/sort/gen_sort_variants.go -exp + package slices import ( @@ -11,57 +13,116 @@ import ( ) // Sort sorts a slice of any ordered type in ascending order. -// Sort may fail to sort correctly when sorting slices of floating-point -// numbers containing Not-a-number (NaN) values. -// Use slices.SortFunc(x, func(a, b float64) bool {return a < b || (math.IsNaN(a) && !math.IsNaN(b))}) -// instead if the input may contain NaNs. -func Sort[E constraints.Ordered](x []E) { +// When sorting floating-point numbers, NaNs are ordered before other values. +func Sort[S ~[]E, E constraints.Ordered](x S) { n := len(x) pdqsortOrdered(x, 0, n, bits.Len(uint(n))) } -// SortFunc sorts the slice x in ascending order as determined by the less function. -// This sort is not guaranteed to be stable. +// SortFunc sorts the slice x in ascending order as determined by the cmp +// function. This sort is not guaranteed to be stable. +// cmp(a, b) should return a negative number when a < b, a positive number when +// a > b and zero when a == b. // -// SortFunc requires that less is a strict weak ordering. +// SortFunc requires that cmp is a strict weak ordering. // See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings. -func SortFunc[E any](x []E, less func(a, b E) bool) { +func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { n := len(x) - pdqsortLessFunc(x, 0, n, bits.Len(uint(n)), less) + pdqsortCmpFunc(x, 0, n, bits.Len(uint(n)), cmp) } // SortStableFunc sorts the slice x while keeping the original order of equal -// elements, using less to compare elements. -func SortStableFunc[E any](x []E, less func(a, b E) bool) { - stableLessFunc(x, len(x), less) +// elements, using cmp to compare elements in the same way as [SortFunc]. +func SortStableFunc[S ~[]E, E any](x S, cmp func(a, b E) int) { + stableCmpFunc(x, len(x), cmp) } // IsSorted reports whether x is sorted in ascending order. -func IsSorted[E constraints.Ordered](x []E) bool { +func IsSorted[S ~[]E, E constraints.Ordered](x S) bool { for i := len(x) - 1; i > 0; i-- { - if x[i] < x[i-1] { + if cmpLess(x[i], x[i-1]) { return false } } return true } -// IsSortedFunc reports whether x is sorted in ascending order, with less as the -// comparison function. -func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool { +// IsSortedFunc reports whether x is sorted in ascending order, with cmp as the +// comparison function as defined by [SortFunc]. +func IsSortedFunc[S ~[]E, E any](x S, cmp func(a, b E) int) bool { for i := len(x) - 1; i > 0; i-- { - if less(x[i], x[i-1]) { + if cmp(x[i], x[i-1]) < 0 { return false } } return true } +// Min returns the minimal value in x. It panics if x is empty. +// For floating-point numbers, Min propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Min[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Min: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = min(m, x[i]) + } + return m +} + +// MinFunc returns the minimal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one minimal element +// according to the cmp function, MinFunc returns the first one. +func MinFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MinFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) < 0 { + m = x[i] + } + } + return m +} + +// Max returns the maximal value in x. It panics if x is empty. +// For floating-point E, Max propagates NaNs (any NaN value in x +// forces the output to be NaN). +func Max[S ~[]E, E constraints.Ordered](x S) E { + if len(x) < 1 { + panic("slices.Max: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + m = max(m, x[i]) + } + return m +} + +// MaxFunc returns the maximal value in x, using cmp to compare elements. +// It panics if x is empty. If there is more than one maximal element +// according to the cmp function, MaxFunc returns the first one. +func MaxFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { + if len(x) < 1 { + panic("slices.MaxFunc: empty list") + } + m := x[0] + for i := 1; i < len(x); i++ { + if cmp(x[i], m) > 0 { + m = x[i] + } + } + return m +} + // BinarySearch searches for target in a sorted slice and returns the position // where target is found, or the position where target would appear in the // sort order; it also returns a bool saying whether the target is really found // in the slice. The slice must be sorted in increasing order. -func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { +func BinarySearch[S ~[]E, E constraints.Ordered](x S, target E) (int, bool) { // Inlining is faster than calling BinarySearchFunc with a lambda. n := len(x) // Define x[-1] < target and x[n] >= target. @@ -70,24 +131,24 @@ func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) { for i < j { h := int(uint(i+j) >> 1) // avoid overflow when computing h // i ≤ h < j - if x[h] < target { + if cmpLess(x[h], target) { i = h + 1 // preserves x[i-1] < target } else { j = h // preserves x[j] >= target } } // i == j, x[i-1] < target, and x[j] (= x[i]) >= target => answer is i. - return i, i < n && x[i] == target + return i, i < n && (x[i] == target || (isNaN(x[i]) && isNaN(target))) } -// BinarySearchFunc works like BinarySearch, but uses a custom comparison +// BinarySearchFunc works like [BinarySearch], but uses a custom comparison // function. The slice must be sorted in increasing order, where "increasing" // is defined by cmp. cmp should return 0 if the slice element matches // the target, a negative number if the slice element precedes the target, // or a positive number if the slice element follows the target. // cmp must implement the same ordering as the slice, such that if // cmp(a, t) < 0 and cmp(b, t) >= 0, then a must precede b in the slice. -func BinarySearchFunc[E, T any](x []E, target T, cmp func(E, T) int) (int, bool) { +func BinarySearchFunc[S ~[]E, E, T any](x S, target T, cmp func(E, T) int) (int, bool) { n := len(x) // Define cmp(x[-1], target) < 0 and cmp(x[n], target) >= 0 . // Invariant: cmp(x[i - 1], target) < 0, cmp(x[j], target) >= 0. @@ -126,3 +187,9 @@ func (r *xorshift) Next() uint64 { func nextPowerOfTwo(length int) uint { return 1 << bits.Len(uint(length)) } + +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. +func isNaN[T constraints.Ordered](x T) bool { + return x != x +} diff --git a/slices/sort_benchmark_test.go b/slices/sort_benchmark_test.go index ee49f66bc..6aa03c81d 100644 --- a/slices/sort_benchmark_test.go +++ b/slices/sort_benchmark_test.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "sort" + "strconv" "strings" "testing" ) @@ -50,6 +51,15 @@ func BenchmarkSortInts(b *testing.B) { } } +func makeSortedStrings(n int) []string { + x := make([]string, n) + for i := 0; i < n; i++ { + x[i] = strconv.Itoa(i) + } + Sort(x) + return x +} + func BenchmarkSlicesSortInts(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() @@ -77,6 +87,24 @@ func BenchmarkSlicesSortInts_Reversed(b *testing.B) { } } +func BenchmarkIntsAreSorted(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + ints := makeSortedInts(N) + b.StartTimer() + sort.IntsAreSorted(ints) + } +} + +func BenchmarkIsSorted(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + ints := makeSortedInts(N) + b.StartTimer() + IsSorted(ints) + } +} + // Since we're benchmarking these sorts against each other, make sure that they // generate similar results. func TestIntSorts(t *testing.T) { @@ -96,7 +124,7 @@ func TestIntSorts(t *testing.T) { // The following is a benchmark for sorting strings. // makeRandomStrings generates n random strings with alphabetic runes of -// varying lenghts. +// varying lengths. func makeRandomStrings(n int) []string { rand.Seed(42) var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -135,6 +163,15 @@ func BenchmarkSortStrings(b *testing.B) { } } +func BenchmarkSortStrings_Sorted(b *testing.B) { + ss := makeSortedStrings(N) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + sort.Strings(ss) + } +} + func BenchmarkSlicesSortStrings(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() @@ -144,6 +181,15 @@ func BenchmarkSlicesSortStrings(b *testing.B) { } } +func BenchmarkSlicesSortStrings_Sorted(b *testing.B) { + ss := makeSortedStrings(N) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + Sort(ss) + } +} + // These benchmarks compare sorting a slice of structs with sort.Sort vs. // slices.SortFunc. type myStruct struct { @@ -174,7 +220,7 @@ func TestStructSorts(t *testing.T) { } sort.Sort(ss) - SortFunc(ss2, func(a, b *myStruct) bool { return a.n < b.n }) + SortFunc(ss2, func(a, b *myStruct) int { return a.n - b.n }) for i := range ss { if *ss[i] != *ss2[i] { @@ -193,12 +239,12 @@ func BenchmarkSortStructs(b *testing.B) { } func BenchmarkSortFuncStructs(b *testing.B) { - lessFunc := func(a, b *myStruct) bool { return a.n < b.n } + cmpFunc := func(a, b *myStruct) int { return a.n - b.n } for i := 0; i < b.N; i++ { b.StopTimer() ss := makeRandomStructs(N) b.StartTimer() - SortFunc(ss, lessFunc) + SortFunc(ss, cmpFunc) } } diff --git a/slices/sort_test.go b/slices/sort_test.go index 3befa3e7e..501b5ee6d 100644 --- a/slices/sort_test.go +++ b/slices/sort_test.go @@ -5,6 +5,7 @@ package slices import ( + "fmt" "math" "math/rand" "sort" @@ -29,7 +30,7 @@ func TestSortIntSlice(t *testing.T) { func TestSortFuncIntSlice(t *testing.T) { data := Clone(ints[:]) - SortFunc(data, func(a, b int) bool { return a < b }) + SortFunc(data, func(a, b int) int { return a - b }) if !IsSorted(data) { t.Errorf("sorted %v", ints) t.Errorf(" got %v", data) @@ -47,17 +48,18 @@ func TestSortFloat64Slice(t *testing.T) { func TestSortFloat64SliceWithNaNs(t *testing.T) { data := float64sWithNaNs[:] - input := Clone(data) + data2 := Clone(data) - // Make sure Sort doesn't panic when the slice contains NaNs. Sort(data) - // Check whether the result is a permutation of the input. - sort.Float64s(data) - sort.Float64s(input) - for i, v := range input { - if data[i] != v && !(math.IsNaN(data[i]) && math.IsNaN(v)) { - t.Fatalf("the result is not a permutation of the input\ngot %v\nwant %v", data, input) - } + sort.Float64s(data2) + + if !IsSorted(data) { + t.Error("IsSorted indicates data isn't sorted") + } + + // Compare for equality using cmp.Compare, which considers NaNs equal. + if !EqualFunc(data, data2, func(a, b float64) bool { return cmpCompare(a, b) == 0 }) { + t.Errorf("mismatch between Sort and sort.Float64: got %v, want %v", data, data2) } } @@ -95,8 +97,8 @@ type intPair struct { type intPairs []intPair // Pairs compare on a only. -func intPairLess(x, y intPair) bool { - return x.a < y.a +func intPairCmp(x, y intPair) int { + return x.a - y.a } // Record initial order in B. @@ -134,12 +136,12 @@ func TestStability(t *testing.T) { for i := 0; i < len(data); i++ { data[i].a = rand.Intn(m) } - if IsSortedFunc(data, intPairLess) { + if IsSortedFunc(data, intPairCmp) { t.Fatalf("terrible rand.rand") } data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable didn't sort %d ints", n) } if !data.inOrder() { @@ -148,8 +150,8 @@ func TestStability(t *testing.T) { // already sorted data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable shuffled sorted %d ints (order)", n) } if !data.inOrder() { @@ -161,8 +163,8 @@ func TestStability(t *testing.T) { data[i].a = len(data) - i } data.initB() - SortStableFunc(data, intPairLess) - if !IsSortedFunc(data, intPairLess) { + SortStableFunc(data, intPairCmp) + if !IsSortedFunc(data, intPairCmp) { t.Errorf("Stable didn't sort %d ints", n) } if !data.inOrder() { @@ -170,6 +172,125 @@ func TestStability(t *testing.T) { } } +type S struct { + a int + b string +} + +func cmpS(s1, s2 S) int { + return cmpCompare(s1.a, s2.a) +} + +func TestMinMax(t *testing.T) { + intCmp := func(a, b int) int { return a - b } + + tests := []struct { + data []int + wantMin int + wantMax int + }{ + {[]int{7}, 7, 7}, + {[]int{1, 2}, 1, 2}, + {[]int{2, 1}, 1, 2}, + {[]int{1, 2, 3}, 1, 3}, + {[]int{3, 2, 1}, 1, 3}, + {[]int{2, 1, 3}, 1, 3}, + {[]int{2, 2, 3}, 2, 3}, + {[]int{3, 2, 3}, 2, 3}, + {[]int{0, 2, -9}, -9, 2}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) { + gotMin := Min(tt.data) + if gotMin != tt.wantMin { + t.Errorf("Min got %v, want %v", gotMin, tt.wantMin) + } + + gotMinFunc := MinFunc(tt.data, intCmp) + if gotMinFunc != tt.wantMin { + t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin) + } + + gotMax := Max(tt.data) + if gotMax != tt.wantMax { + t.Errorf("Max got %v, want %v", gotMax, tt.wantMax) + } + + gotMaxFunc := MaxFunc(tt.data, intCmp) + if gotMaxFunc != tt.wantMax { + t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax) + } + }) + } + + svals := []S{ + {1, "a"}, + {2, "a"}, + {1, "b"}, + {2, "b"}, + } + + gotMin := MinFunc(svals, cmpS) + wantMin := S{1, "a"} + if gotMin != wantMin { + t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin) + } + + gotMax := MaxFunc(svals, cmpS) + wantMax := S{2, "a"} + if gotMax != wantMax { + t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax) + } +} + +func TestMinMaxNaNs(t *testing.T) { + fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14} + if Min(fs) != -400.4 { + t.Errorf("got min %v, want -400.4", Min(fs)) + } + if Max(fs) != 999.9 { + t.Errorf("got max %v, want 999.9", Max(fs)) + } + + // No matter which element of fs is replaced with a NaN, both Min and Max + // should propagate the NaN to their output. + for i := 0; i < len(fs); i++ { + testfs := Clone(fs) + testfs[i] = math.NaN() + + fmin := Min(testfs) + if !math.IsNaN(fmin) { + t.Errorf("got min %v, want NaN", fmin) + } + + fmax := Max(testfs) + if !math.IsNaN(fmax) { + t.Errorf("got max %v, want NaN", fmax) + } + } +} + +func TestMinMaxPanics(t *testing.T) { + intCmp := func(a, b int) int { return a - b } + emptySlice := []int{} + + if !panics(func() { Min(emptySlice) }) { + t.Errorf("Min([]): got no panic, want panic") + } + + if !panics(func() { Max(emptySlice) }) { + t.Errorf("Max([]): got no panic, want panic") + } + + if !panics(func() { MinFunc(emptySlice, intCmp) }) { + t.Errorf("MinFunc([]): got no panic, want panic") + } + + if !panics(func() { MaxFunc(emptySlice, intCmp) }) { + t.Errorf("MaxFunc([]): got no panic, want panic") + } +} + func TestBinarySearch(t *testing.T) { str1 := []string{"foo"} str2 := []string{"ab", "ca"} @@ -282,6 +403,32 @@ func TestBinarySearchInts(t *testing.T) { } } +func TestBinarySearchFloats(t *testing.T) { + data := []float64{math.NaN(), -0.25, 0.0, 1.4} + tests := []struct { + target float64 + wantPos int + wantFound bool + }{ + {math.NaN(), 0, true}, + {math.Inf(-1), 1, false}, + {-0.25, 1, true}, + {0.0, 2, true}, + {1.4, 3, true}, + {1.5, 4, false}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) { + { + pos, found := BinarySearch(data, tt.target) + if pos != tt.wantPos || found != tt.wantFound { + t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) + } + } + }) + } +} + func TestBinarySearchFunc(t *testing.T) { data := []int{1, 10, 11, 2} // sorted lexicographically cmp := func(a int, b string) int { diff --git a/slices/zsortfunc.go b/slices/zsortanyfunc.go similarity index 64% rename from slices/zsortfunc.go rename to slices/zsortanyfunc.go index 2a632476c..06f2c7a24 100644 --- a/slices/zsortfunc.go +++ b/slices/zsortanyfunc.go @@ -6,28 +6,28 @@ package slices -// insertionSortLessFunc sorts data[a:b] using insertion sort. -func insertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +// insertionSortCmpFunc sorts data[a:b] using insertion sort. +func insertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { for i := a + 1; i < b; i++ { - for j := i; j > a && less(data[j], data[j-1]); j-- { + for j := i; j > a && (cmp(data[j], data[j-1]) < 0); j-- { data[j], data[j-1] = data[j-1], data[j] } } } -// siftDownLessFunc implements the heap property on data[lo:hi]. +// siftDownCmpFunc implements the heap property on data[lo:hi]. // first is an offset into the array where the root of the heap lies. -func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool) { +func siftDownCmpFunc[E any](data []E, lo, hi, first int, cmp func(a, b E) int) { root := lo for { child := 2*root + 1 if child >= hi { break } - if child+1 < hi && less(data[first+child], data[first+child+1]) { + if child+1 < hi && (cmp(data[first+child], data[first+child+1]) < 0) { child++ } - if !less(data[first+root], data[first+child]) { + if !(cmp(data[first+root], data[first+child]) < 0) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -35,30 +35,30 @@ func siftDownLessFunc[E any](data []E, lo, hi, first int, less func(a, b E) bool } } -func heapSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func heapSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { first := a lo := 0 hi := b - a // Build heap with greatest element at top. for i := (hi - 1) / 2; i >= 0; i-- { - siftDownLessFunc(data, i, hi, first, less) + siftDownCmpFunc(data, i, hi, first, cmp) } // Pop elements, largest first, into end of data. for i := hi - 1; i >= 0; i-- { data[first], data[first+i] = data[first+i], data[first] - siftDownLessFunc(data, lo, i, first, less) + siftDownCmpFunc(data, lo, i, first, cmp) } } -// pdqsortLessFunc sorts data[a:b]. +// pdqsortCmpFunc sorts data[a:b]. // The algorithm based on pattern-defeating quicksort(pdqsort), but without the optimizations from BlockQuicksort. // pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf // C++ implementation: https://github.com/orlp/pdqsort // Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/ // limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort. -func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { +func pdqsortCmpFunc[E any](data []E, a, b, limit int, cmp func(a, b E) int) { const maxInsertion = 12 var ( @@ -70,25 +70,25 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { length := b - a if length <= maxInsertion { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) return } // Fall back to heapsort if too many bad choices were made. if limit == 0 { - heapSortLessFunc(data, a, b, less) + heapSortCmpFunc(data, a, b, cmp) return } // If the last partitioning was imbalanced, we need to breaking patterns. if !wasBalanced { - breakPatternsLessFunc(data, a, b, less) + breakPatternsCmpFunc(data, a, b, cmp) limit-- } - pivot, hint := choosePivotLessFunc(data, a, b, less) + pivot, hint := choosePivotCmpFunc(data, a, b, cmp) if hint == decreasingHint { - reverseRangeLessFunc(data, a, b, less) + reverseRangeCmpFunc(data, a, b, cmp) // The chosen pivot was pivot-a elements after the start of the array. // After reversing it is pivot-a elements before the end of the array. // The idea came from Rust's implementation. @@ -98,48 +98,48 @@ func pdqsortLessFunc[E any](data []E, a, b, limit int, less func(a, b E) bool) { // The slice is likely already sorted. if wasBalanced && wasPartitioned && hint == increasingHint { - if partialInsertionSortLessFunc(data, a, b, less) { + if partialInsertionSortCmpFunc(data, a, b, cmp) { return } } // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. - if a > 0 && !less(data[a-1], data[pivot]) { - mid := partitionEqualLessFunc(data, a, b, pivot, less) + if a > 0 && !(cmp(data[a-1], data[pivot]) < 0) { + mid := partitionEqualCmpFunc(data, a, b, pivot, cmp) a = mid continue } - mid, alreadyPartitioned := partitionLessFunc(data, a, b, pivot, less) + mid, alreadyPartitioned := partitionCmpFunc(data, a, b, pivot, cmp) wasPartitioned = alreadyPartitioned leftLen, rightLen := mid-a, b-mid balanceThreshold := length / 8 if leftLen < rightLen { wasBalanced = leftLen >= balanceThreshold - pdqsortLessFunc(data, a, mid, limit, less) + pdqsortCmpFunc(data, a, mid, limit, cmp) a = mid + 1 } else { wasBalanced = rightLen >= balanceThreshold - pdqsortLessFunc(data, mid+1, b, limit, less) + pdqsortCmpFunc(data, mid+1, b, limit, cmp) b = mid } } } -// partitionLessFunc does one quicksort partition. +// partitionCmpFunc does one quicksort partition. // Let p = data[pivot] // Moves elements in data[a:b] around, so that data[i]

=p for inewpivot. // On return, data[newpivot] = p -func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int, alreadyPartitioned bool) { +func partitionCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int, alreadyPartitioned bool) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -151,10 +151,10 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) j-- for { - for i <= j && less(data[i], data[a]) { + for i <= j && (cmp(data[i], data[a]) < 0) { i++ } - for i <= j && !less(data[j], data[a]) { + for i <= j && !(cmp(data[j], data[a]) < 0) { j-- } if i > j { @@ -168,17 +168,17 @@ func partitionLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) return j, false } -// partitionEqualLessFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. +// partitionEqualCmpFunc partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot]. // It assumed that data[a:b] does not contain elements smaller than the data[pivot]. -func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) bool) (newpivot int) { +func partitionEqualCmpFunc[E any](data []E, a, b, pivot int, cmp func(a, b E) int) (newpivot int) { data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !less(data[a], data[i]) { + for i <= j && !(cmp(data[a], data[i]) < 0) { i++ } - for i <= j && less(data[a], data[j]) { + for i <= j && (cmp(data[a], data[j]) < 0) { j-- } if i > j { @@ -191,15 +191,15 @@ func partitionEqualLessFunc[E any](data []E, a, b, pivot int, less func(a, b E) return i } -// partialInsertionSortLessFunc partially sorts a slice, returns true if the slice is sorted at the end. -func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) bool) bool { +// partialInsertionSortCmpFunc partially sorts a slice, returns true if the slice is sorted at the end. +func partialInsertionSortCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) bool { const ( maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted shortestShifting = 50 // don't shift any elements on short arrays ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !less(data[i], data[i-1]) { + for i < b && !(cmp(data[i], data[i-1]) < 0) { i++ } @@ -216,7 +216,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -225,7 +225,7 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b // Shift the greater one to the right. if b-i >= 2 { for j := i + 1; j < b; j++ { - if !less(data[j], data[j-1]) { + if !(cmp(data[j], data[j-1]) < 0) { break } data[j], data[j-1] = data[j-1], data[j] @@ -235,9 +235,9 @@ func partialInsertionSortLessFunc[E any](data []E, a, b int, less func(a, b E) b return false } -// breakPatternsLessFunc scatters some elements around in an attempt to break some patterns +// breakPatternsCmpFunc scatters some elements around in an attempt to break some patterns // that might cause imbalanced partitions in quicksort. -func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func breakPatternsCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { length := b - a if length >= 8 { random := xorshift(length) @@ -253,12 +253,12 @@ func breakPatternsLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -// choosePivotLessFunc chooses a pivot in data[a:b]. +// choosePivotCmpFunc chooses a pivot in data[a:b]. // // [0,8): chooses a static pivot. // [8,shortestNinther): uses the simple median-of-three method. // [shortestNinther,∞): uses the Tukey ninther method. -func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (pivot int, hint sortedHint) { +func choosePivotCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) (pivot int, hint sortedHint) { const ( shortestNinther = 50 maxSwaps = 4 * 3 @@ -276,12 +276,12 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv if l >= 8 { if l >= shortestNinther { // Tukey ninther method, the idea came from Rust's implementation. - i = medianAdjacentLessFunc(data, i, &swaps, less) - j = medianAdjacentLessFunc(data, j, &swaps, less) - k = medianAdjacentLessFunc(data, k, &swaps, less) + i = medianAdjacentCmpFunc(data, i, &swaps, cmp) + j = medianAdjacentCmpFunc(data, j, &swaps, cmp) + k = medianAdjacentCmpFunc(data, k, &swaps, cmp) } // Find the median among i, j, k and stores it into j. - j = medianLessFunc(data, i, j, k, &swaps, less) + j = medianCmpFunc(data, i, j, k, &swaps, cmp) } switch swaps { @@ -294,29 +294,29 @@ func choosePivotLessFunc[E any](data []E, a, b int, less func(a, b E) bool) (piv } } -// order2LessFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. -func order2LessFunc[E any](data []E, a, b int, swaps *int, less func(a, b E) bool) (int, int) { - if less(data[b], data[a]) { +// order2CmpFunc returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. +func order2CmpFunc[E any](data []E, a, b int, swaps *int, cmp func(a, b E) int) (int, int) { + if cmp(data[b], data[a]) < 0 { *swaps++ return b, a } return a, b } -// medianLessFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. -func medianLessFunc[E any](data []E, a, b, c int, swaps *int, less func(a, b E) bool) int { - a, b = order2LessFunc(data, a, b, swaps, less) - b, c = order2LessFunc(data, b, c, swaps, less) - a, b = order2LessFunc(data, a, b, swaps, less) +// medianCmpFunc returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c. +func medianCmpFunc[E any](data []E, a, b, c int, swaps *int, cmp func(a, b E) int) int { + a, b = order2CmpFunc(data, a, b, swaps, cmp) + b, c = order2CmpFunc(data, b, c, swaps, cmp) + a, b = order2CmpFunc(data, a, b, swaps, cmp) return b } -// medianAdjacentLessFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. -func medianAdjacentLessFunc[E any](data []E, a int, swaps *int, less func(a, b E) bool) int { - return medianLessFunc(data, a-1, a, a+1, swaps, less) +// medianAdjacentCmpFunc finds the median of data[a - 1], data[a], data[a + 1] and stores the index into a. +func medianAdjacentCmpFunc[E any](data []E, a int, swaps *int, cmp func(a, b E) int) int { + return medianCmpFunc(data, a-1, a, a+1, swaps, cmp) } -func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { +func reverseRangeCmpFunc[E any](data []E, a, b int, cmp func(a, b E) int) { i := a j := b - 1 for i < j { @@ -326,37 +326,37 @@ func reverseRangeLessFunc[E any](data []E, a, b int, less func(a, b E) bool) { } } -func swapRangeLessFunc[E any](data []E, a, b, n int, less func(a, b E) bool) { +func swapRangeCmpFunc[E any](data []E, a, b, n int, cmp func(a, b E) int) { for i := 0; i < n; i++ { data[a+i], data[b+i] = data[b+i], data[a+i] } } -func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { +func stableCmpFunc[E any](data []E, n int, cmp func(a, b E) int) { blockSize := 20 // must be > 0 a, b := 0, blockSize for b <= n { - insertionSortLessFunc(data, a, b, less) + insertionSortCmpFunc(data, a, b, cmp) a = b b += blockSize } - insertionSortLessFunc(data, a, n, less) + insertionSortCmpFunc(data, a, n, cmp) for blockSize < n { a, b = 0, 2*blockSize for b <= n { - symMergeLessFunc(data, a, a+blockSize, b, less) + symMergeCmpFunc(data, a, a+blockSize, b, cmp) a = b b += 2 * blockSize } if m := a + blockSize; m < n { - symMergeLessFunc(data, a, m, n, less) + symMergeCmpFunc(data, a, m, n, cmp) } blockSize *= 2 } } -// symMergeLessFunc merges the two sorted subsequences data[a:m] and data[m:b] using +// symMergeCmpFunc merges the two sorted subsequences data[a:m] and data[m:b] using // the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum // Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz // Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in @@ -375,7 +375,7 @@ func stableLessFunc[E any](data []E, n int, less func(a, b E) bool) { // symMerge assumes non-degenerate arguments: a < m && m < b. // Having the caller check this condition eliminates many leaf recursion calls, // which improves performance. -func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func symMergeCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { // Avoid unnecessary recursions of symMerge // by direct insertion of data[a] into data[m:b] // if data[a:m] only contains one element. @@ -387,7 +387,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := b for i < j { h := int(uint(i+j) >> 1) - if less(data[h], data[a]) { + if cmp(data[h], data[a]) < 0 { i = h + 1 } else { j = h @@ -411,7 +411,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { j := m for i < j { h := int(uint(i+j) >> 1) - if !less(data[m], data[h]) { + if !(cmp(data[m], data[h]) < 0) { i = h + 1 } else { j = h @@ -438,7 +438,7 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { for start < r { c := int(uint(start+r) >> 1) - if !less(data[p-c], data[c]) { + if !(cmp(data[p-c], data[c]) < 0) { start = c + 1 } else { r = c @@ -447,33 +447,33 @@ func symMergeLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { end := n - start if start < m && m < end { - rotateLessFunc(data, start, m, end, less) + rotateCmpFunc(data, start, m, end, cmp) } if a < start && start < mid { - symMergeLessFunc(data, a, start, mid, less) + symMergeCmpFunc(data, a, start, mid, cmp) } if mid < end && end < b { - symMergeLessFunc(data, mid, end, b, less) + symMergeCmpFunc(data, mid, end, b, cmp) } } -// rotateLessFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: +// rotateCmpFunc rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data: // Data of the form 'x u v y' is changed to 'x v u y'. // rotate performs at most b-a many calls to data.Swap, // and it assumes non-degenerate arguments: a < m && m < b. -func rotateLessFunc[E any](data []E, a, m, b int, less func(a, b E) bool) { +func rotateCmpFunc[E any](data []E, a, m, b int, cmp func(a, b E) int) { i := m - a j := b - m for i != j { if i > j { - swapRangeLessFunc(data, m-i, m, j, less) + swapRangeCmpFunc(data, m-i, m, j, cmp) i -= j } else { - swapRangeLessFunc(data, m-i, m+j-i, i, less) + swapRangeCmpFunc(data, m-i, m+j-i, i, cmp) j -= i } } // i == j - swapRangeLessFunc(data, m-i, m, i, less) + swapRangeCmpFunc(data, m-i, m, i, cmp) } diff --git a/slices/zsortordered.go b/slices/zsortordered.go index efaa1c8b7..99b47c398 100644 --- a/slices/zsortordered.go +++ b/slices/zsortordered.go @@ -11,7 +11,7 @@ import "golang.org/x/exp/constraints" // insertionSortOrdered sorts data[a:b] using insertion sort. func insertionSortOrdered[E constraints.Ordered](data []E, a, b int) { for i := a + 1; i < b; i++ { - for j := i; j > a && (data[j] < data[j-1]); j-- { + for j := i; j > a && cmpLess(data[j], data[j-1]); j-- { data[j], data[j-1] = data[j-1], data[j] } } @@ -26,10 +26,10 @@ func siftDownOrdered[E constraints.Ordered](data []E, lo, hi, first int) { if child >= hi { break } - if child+1 < hi && (data[first+child] < data[first+child+1]) { + if child+1 < hi && cmpLess(data[first+child], data[first+child+1]) { child++ } - if !(data[first+root] < data[first+child]) { + if !cmpLess(data[first+root], data[first+child]) { return } data[first+root], data[first+child] = data[first+child], data[first+root] @@ -107,7 +107,7 @@ func pdqsortOrdered[E constraints.Ordered](data []E, a, b, limit int) { // Probably the slice contains many duplicate elements, partition the slice into // elements equal to and elements greater than the pivot. - if a > 0 && !(data[a-1] < data[pivot]) { + if a > 0 && !cmpLess(data[a-1], data[pivot]) { mid := partitionEqualOrdered(data, a, b, pivot) a = mid continue @@ -138,10 +138,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo data[a], data[pivot] = data[pivot], data[a] i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -153,10 +153,10 @@ func partitionOrdered[E constraints.Ordered](data []E, a, b, pivot int) (newpivo j-- for { - for i <= j && (data[i] < data[a]) { + for i <= j && cmpLess(data[i], data[a]) { i++ } - for i <= j && !(data[j] < data[a]) { + for i <= j && !cmpLess(data[j], data[a]) { j-- } if i > j { @@ -177,10 +177,10 @@ func partitionEqualOrdered[E constraints.Ordered](data []E, a, b, pivot int) (ne i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned for { - for i <= j && !(data[a] < data[i]) { + for i <= j && !cmpLess(data[a], data[i]) { i++ } - for i <= j && (data[a] < data[j]) { + for i <= j && cmpLess(data[a], data[j]) { j-- } if i > j { @@ -201,7 +201,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool ) i := a + 1 for j := 0; j < maxSteps; j++ { - for i < b && !(data[i] < data[i-1]) { + for i < b && !cmpLess(data[i], data[i-1]) { i++ } @@ -218,7 +218,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the smaller one to the left. if i-a >= 2 { for j := i - 1; j >= 1; j-- { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -227,7 +227,7 @@ func partialInsertionSortOrdered[E constraints.Ordered](data []E, a, b int) bool // Shift the greater one to the right. if b-i >= 2 { for j := i + 1; j < b; j++ { - if !(data[j] < data[j-1]) { + if !cmpLess(data[j], data[j-1]) { break } data[j], data[j-1] = data[j-1], data[j] @@ -298,7 +298,7 @@ func choosePivotOrdered[E constraints.Ordered](data []E, a, b int) (pivot int, h // order2Ordered returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a. func order2Ordered[E constraints.Ordered](data []E, a, b int, swaps *int) (int, int) { - if data[b] < data[a] { + if cmpLess(data[b], data[a]) { *swaps++ return b, a } @@ -389,7 +389,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := b for i < j { h := int(uint(i+j) >> 1) - if data[h] < data[a] { + if cmpLess(data[h], data[a]) { i = h + 1 } else { j = h @@ -413,7 +413,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { j := m for i < j { h := int(uint(i+j) >> 1) - if !(data[m] < data[h]) { + if !cmpLess(data[m], data[h]) { i = h + 1 } else { j = h @@ -440,7 +440,7 @@ func symMergeOrdered[E constraints.Ordered](data []E, a, m, b int) { for start < r { c := int(uint(start+r) >> 1) - if !(data[p-c] < data[c]) { + if !cmpLess(data[p-c], data[c]) { start = c + 1 } else { r = c