diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 2ffc5fdc..2e168604 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -1,4 +1,8 @@ # 开发中 +- [syncx: 支持分key加锁](https://github.com/ecodeclub/ekit/pull/224) +- [syncx: 添加具有最大申请次数限制的LimitPool](https://github.com/ecodeclub/ekit/pull/233) +- [tuple: 增加Pair的实现](https://github.com/ecodeclub/ekit/pull/237) +- [randx: 重构randx.RandCode的代码,增加对特殊字符的支持](https://github.com/ecodeclub/ekit/pull/241) # v0.0.8 - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) @@ -35,7 +39,9 @@ - [sqlx: Scanner 添加 NextResultSet 方法](https://github.com/ecodeclub/ekit/pull/212) - [ekit: AnyValue 支持As[Type]类型 String 转换](https://github.com/ecodeclub/ekit/pull/213) - [stringx: unsafe 转换 string 和 []byte](https://github.com/ecodeclub/ekit/pull/215) - - [stringx: 添加 Benchmark](https://github.com/ecodeclub/ekit/pull/216) +- [stringx: 添加 Benchmark](https://github.com/ecodeclub/ekit/pull/216) +- [tree: 把 internal 里的红黑树做一个简单封装](https://github.com/ecodeclub/ekit/pull/218) +- [queue: 把 internal 里的优先级队列做一个简单封装](https://github.com/ecodeclub/ekit/pull/218) # v0.0.7 - [slice: FilterDelete](https://github.com/ecodeclub/ekit/pull/152) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml deleted file mode 100644 index d08cba52..00000000 --- a/.github/workflows/changelog.yml +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2021 ecodeclub -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: changelog - -on: - pull_request: - types: [opened, synchronize, reopened, labeled, unlabeled] - branches: - - develop - - main - - dev - -jobs: - changelog: - runs-on: ubuntu-latest - if: "!contains(github.event.pull_request.labels.*.name, 'Skip Changelog')" - - steps: - - uses: actions/checkout@v2 - - - name: Check for CHANGELOG changes - run: | - # Only the latest commit of the feature branch is available - # automatically. To diff with the base branch, we need to - # fetch that too (and we only need its latest commit). - git fetch origin ${{ github.base_ref }} --depth=1 - if [[ $(git diff --name-only FETCH_HEAD | grep CHANGELOG) ]] - then - echo "A CHANGELOG was modified. Looks good!" - else - echo "No CHANGELOG was modified." - echo "Please add a CHANGELOG entry, or add the \"Skip Changelog\" label if not required." - false - fi \ No newline at end of file diff --git a/script/goimports.sh b/.script/goimports.sh similarity index 100% rename from script/goimports.sh rename to .script/goimports.sh diff --git a/script/setup.sh b/.script/setup.sh similarity index 100% rename from script/setup.sh rename to .script/setup.sh diff --git a/Makefile b/Makefile index c7e6056c..824788f7 100644 --- a/Makefile +++ b/Makefile @@ -8,11 +8,11 @@ ut: .PHONY: setup setup: - @sh ./script/setup.sh + @sh ./.script/setup.sh .PHONY: fmt fmt: - @sh ./script/goimports.sh + @sh ./.script/goimports.sh .PHONY: lint lint: diff --git a/README.md b/README.md index aecce8a5..7ea4737c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # ekit 泛型工具库。 -- [文档](https://ekit.gocn.vip/ekit/develop/guide/) +- [文档](https://doc.meoying.com/) ## 交流 diff --git a/go.mod b/go.mod index 80ea97db..bce2f00a 100644 --- a/go.mod +++ b/go.mod @@ -5,15 +5,16 @@ go 1.20 require ( github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/mattn/go-sqlite3 v1.14.15 - github.com/stretchr/testify v1.8.1 - golang.org/x/sync v0.1.0 + github.com/stretchr/testify v1.8.4 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d + golang.org/x/sync v0.4.0 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 92a1b932..010d4b3e 100644 --- a/go.sum +++ b/go.sum @@ -1,31 +1,31 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= +golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/errs/error.go b/internal/errs/error.go index a6f1e429..d05f0494 100644 --- a/internal/errs/error.go +++ b/internal/errs/error.go @@ -25,8 +25,8 @@ func NewErrIndexOutOfRange(length int, index int) error { } // NewErrInvalidType 创建一个代表类型转换失败的错误 -func NewErrInvalidType(want, got string) error { - return fmt.Errorf("ekit: 类型转换失败,want:%s, got:%s", want, got) +func NewErrInvalidType(want string, got any) error { + return fmt.Errorf("ekit: 类型转换失败,预期类型:%s, 实际值:%#v", want, got) } func NewErrInvalidIntervalValue(interval time.Duration) error { diff --git a/internal/list/skip_list.go b/internal/list/skip_list.go new file mode 100644 index 00000000..1167a721 --- /dev/null +++ b/internal/list/skip_list.go @@ -0,0 +1,174 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package list + +import ( + "errors" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/errs" + "golang.org/x/exp/rand" +) + +// 跳表 skip list + +const ( + FactorP = float32(0.25) // level i 上的结点 有FactorP的比例出现在level i + 1上 + MaxLevel = 32 +) + +// FactorP = 0.25, MaxLevel = 32 列表可包含 2^64 个元素 + +type skipListNode[T any] struct { + Val T + Forward []*skipListNode[T] +} + +type SkipList[T any] struct { + header *skipListNode[T] + level int // SkipList为空时, level为1 + compare ekit.Comparator[T] + size int +} + +func newSkipListNode[T any](Val T, level int) *skipListNode[T] { + return &skipListNode[T]{Val, make([]*skipListNode[T], level)} +} + +func (sl *SkipList[T]) AsSlice() []T { + curr := sl.header + slice := make([]T, 0, sl.size) + for curr.Forward[0] != nil { + slice = append(slice, curr.Forward[0].Val) + curr = curr.Forward[0] + } + return slice +} + +func NewSkipListFromSlice[T any](slice []T, compare ekit.Comparator[T]) *SkipList[T] { + sl := NewSkipList[T](compare) + for _, n := range slice { + sl.Insert(n) + } + return sl +} + +func NewSkipList[T any](compare ekit.Comparator[T]) *SkipList[T] { + return &SkipList[T]{ + header: &skipListNode[T]{ + Forward: make([]*skipListNode[T], MaxLevel), + }, + level: 1, + compare: compare, + } +} + +// levels的生成和跳表中元素个数无关 +func (sl *SkipList[T]) randomLevel() int { + level := 1 + p := FactorP + for (rand.Int31() & 0xFFFF) < int32(p*0xFFFF) { + level++ + } + if level < MaxLevel { + return level + } + return MaxLevel + +} + +func (sl *SkipList[T]) Search(target T) bool { + curr, _ := sl.traverse(target, sl.level) + curr = curr.Forward[0] // 第1层 包含所有元素 + return curr != nil && sl.compare(curr.Val, target) == 0 +} + +func (sl *SkipList[T]) traverse(Val T, level int) (*skipListNode[T], []*skipListNode[T]) { + update := make([]*skipListNode[T], MaxLevel) // update[i] 包含位于level i 的插入/删除位置左侧的指针 + curr := sl.header + for i := level - 1; i >= 0; i-- { + for curr.Forward[i] != nil && sl.compare(curr.Forward[i].Val, Val) < 0 { + curr = curr.Forward[i] + } + update[i] = curr + } + return curr, update +} + +func (sl *SkipList[T]) Insert(Val T) { + _, update := sl.traverse(Val, sl.level) + level := sl.randomLevel() + if level > sl.level { + for i := sl.level; i < level; i++ { + update[i] = sl.header + } + sl.level = level + } + + // 插入新节点 + newNode := newSkipListNode[T](Val, level) + for i := 0; i < level; i++ { + newNode.Forward[i] = update[i].Forward[i] + update[i].Forward[i] = newNode + } + + sl.size += 1 + +} + +func (sl *SkipList[T]) Len() int { + return sl.size +} + +func (sl *SkipList[T]) DeleteElement(target T) bool { + curr, update := sl.traverse(target, sl.level) + node := curr.Forward[0] + if node == nil || sl.compare(node.Val, target) != 0 { + return true + } + // 删除target结点 + for i := 0; i < sl.level && update[i].Forward[i] == node; i++ { + update[i].Forward[i] = node.Forward[i] + } + + // 更新层级 + for sl.level > 1 && sl.header.Forward[sl.level-1] == nil { + sl.level-- + } + sl.size -= 1 + return true +} + +func (sl *SkipList[T]) Peek() (T, error) { + curr := sl.header + curr = curr.Forward[0] + var zero T + if curr == nil { + return zero, errors.New("跳表为空") + } + return curr.Val, nil +} + +func (sl *SkipList[T]) Get(index int) (T, error) { + var zero T + if index < 0 || index >= sl.size { + return zero, errs.NewErrIndexOutOfRange(sl.size, index) + } + curr := sl.header + for i := 0; i <= index; i++ { + curr = curr.Forward[0] + } + return curr.Val, nil +} diff --git a/internal/list/skip_list_test.go b/internal/list/skip_list_test.go new file mode 100644 index 00000000..e27bace3 --- /dev/null +++ b/internal/list/skip_list_test.go @@ -0,0 +1,459 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package list + +import ( + "errors" + "fmt" + "testing" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/errs" + "github.com/stretchr/testify/assert" +) + +func TestNewSkipList(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + level int + wantHeader *skipListNode[int] + wantLevel int + wantSlice []int + wantErr error + wantSize int + }{ + { + name: "new skip list", + compare: ekit.ComparatorRealNumber[int], + level: 1, + wantLevel: 1, + wantHeader: newSkipListNode[int](0, MaxLevel), + wantSlice: []int{}, + wantSize: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList(tc.compare) + assert.Equal(t, tc.wantLevel, sl.level) + assert.Equal(t, tc.wantHeader, sl.header) + assert.Equal(t, tc.wantSlice, sl.AsSlice()) + assert.Equal(t, tc.wantSize, sl.size) + + }) + } +} + +func TestNewSkipListFromSlice(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + level int + slice []int + + wantSlice []int + wantErr error + wantSize int + }{ + { + name: "new skip list", + compare: ekit.ComparatorRealNumber[int], + level: 1, + slice: []int{1, 2, 3}, + + wantSlice: []int{1, 2, 3}, + wantSize: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipListFromSlice[int](tc.slice, tc.compare) + assert.Equal(t, tc.wantSlice, sl.AsSlice()) + assert.Equal(t, tc.wantSize, sl.size) + + }) + } +} + +//func TestSkipListToSlice(t *testing.T) { +// +//} + +func TestSkipList_DeleteElement(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + value int + wantSlice []int + wantSize int + wantRes bool + }{ + { + name: "delete 2 from [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + value: 2, + wantSlice: []int{1, 3}, + wantSize: 2, + wantRes: true, + }, + { + name: "delete 1 from [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{3}, + wantSize: 1, + wantRes: true, + }, + { + name: "delete 1 from []", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{}, + wantSize: 0, + wantRes: true, + }, + { + name: "delete 1 from [1]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{}, + wantSize: 0, + wantRes: true, + }, + { + name: "delete 1 from [2]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{2}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{2}, + wantSize: 1, + wantRes: true, + }, + { + name: "delete 3 from [1,2,3,4,5,6,7]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3, 4, 5, 6, 7}, ekit.ComparatorRealNumber[int]), + value: 3, + wantSlice: []int{1, 2, 4, 5, 6, 7}, + wantSize: 6, + wantRes: true, + }, + { + name: "delete 8 from [1,2,3,4,5,6,7]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3, 4, 5, 6, 7}, ekit.ComparatorRealNumber[int]), + value: 8, + wantSlice: []int{1, 2, 3, 4, 5, 6, 7}, + wantSize: 7, + wantRes: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ok := tc.skiplist.DeleteElement(tc.value) + assert.Equal(t, tc.wantSize, tc.skiplist.size) + assert.Equal(t, tc.wantSlice, tc.skiplist.AsSlice()) + assert.Equal(t, tc.wantRes, ok) + }) + } +} + +func TestSkipList_Insert(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + value int + wantSlice []int + wantSize int + }{ + { + name: "insert 2 into [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + value: 2, + wantSlice: []int{1, 2, 3}, + wantSize: 3, + }, + { + name: "insert 1 into []", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{1}, + wantSize: 1, + }, + { + name: "insert 2 into [1,2,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + value: 2, + wantSlice: []int{1, 2, 2, 3}, + wantSize: 4, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.skiplist.Insert(tc.value) + assert.Equal(t, tc.wantSize, tc.skiplist.size) + assert.Equal(t, tc.wantSlice, tc.skiplist.AsSlice()) + }) + } +} + +func TestSkipList_Search(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + value int + wantSlice []int + wantSize int + wantRes bool + }{ + { + name: "search 2 from [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + value: 2, + wantSlice: []int{1, 3}, + wantSize: 2, + wantRes: false, + }, + { + name: "search 1 from [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{1, 3}, + wantSize: 2, + wantRes: true, + }, + { + name: "search 1 from []", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{}, + wantSize: 0, + wantRes: false, + }, + { + name: "search 1 from [1]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{1}, + wantSize: 1, + wantRes: true, + }, + { + name: "search 1 from [2]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{2}, ekit.ComparatorRealNumber[int]), + value: 1, + wantSlice: []int{2}, + wantSize: 1, + wantRes: false, + }, + { + name: "search 3 from [1,2,3,4,5,6]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3, 4, 5, 6}, ekit.ComparatorRealNumber[int]), + value: 3, + wantSlice: []int{1, 2, 3, 4, 5, 6}, + wantSize: 6, + wantRes: true, + }, + { + name: "search 8 from [1,2,3,4,5,6]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3, 4, 5, 6}, ekit.ComparatorRealNumber[int]), + value: 8, + wantSlice: []int{1, 2, 3, 4, 5, 6}, + wantSize: 6, + wantRes: false, + }, + { + name: "search 2 from [1,2,2,3,3,4,5,6]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 2, 3, 3, 4, 5, 6}, ekit.ComparatorRealNumber[int]), + value: 2, + wantSlice: []int{1, 2, 2, 3, 3, 4, 5, 6}, + wantSize: 8, + wantRes: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ok := tc.skiplist.Search(tc.value) + assert.Equal(t, tc.wantSize, tc.skiplist.size) + assert.Equal(t, tc.wantSlice, tc.skiplist.AsSlice()) + assert.Equal(t, tc.wantRes, ok) + }) + } +} + +func TestSkipList_randomLevel(t *testing.T) { + sl := NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]) + fmt.Println(sl.randomLevel()) +} + +func TestSkipList_Peek(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + wantSlice []int + wantVal int + wantErr error + }{ + { + name: "peek [1,3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 3}, ekit.ComparatorRealNumber[int]), + wantSlice: []int{1, 3}, + wantVal: 1, + wantErr: nil, + }, + { + name: "peek []", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{}, ekit.ComparatorRealNumber[int]), + wantSlice: []int{}, + wantVal: 0, + wantErr: errors.New("跳表为空"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val, err := tc.skiplist.Peek() + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestSkipList_Get(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + index int + wantSlice []int + wantVal int + wantErr error + }{ + { + name: "get index -1 [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + index: -1, + wantSlice: []int{1, 2, 3}, + wantVal: 0, + wantErr: errs.NewErrIndexOutOfRange(3, -1), + }, + { + name: "get index 3 [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + index: 3, + wantSlice: []int{1, 2, 3}, + wantVal: 0, + wantErr: errs.NewErrIndexOutOfRange(3, 3), + }, + { + name: "get index 0 [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + index: 0, + wantSlice: []int{1, 2, 3}, + wantVal: 1, + wantErr: nil, + }, + { + name: "get index 1 [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + index: 1, + wantSlice: []int{1, 2, 3}, + wantVal: 2, + wantErr: nil, + }, + { + name: "get index 2 [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + index: 2, + wantSlice: []int{1, 2, 3}, + wantVal: 3, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val, err := tc.skiplist.Get(tc.index) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func TestSkipList_AsSlice(t *testing.T) { + testCases := []struct { + name string + skiplist *SkipList[int] + compare ekit.Comparator[int] + wantSlice []int + }{ + { + name: " [1, 2, 3]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + wantSlice: []int{1, 2, 3}, + }, + { + name: "[3,2,1]]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{3, 2, 1}, ekit.ComparatorRealNumber[int]), + wantSlice: []int{1, 2, 3}, + }, + { + name: "[]", + compare: ekit.ComparatorRealNumber[int], + skiplist: NewSkipListFromSlice[int]([]int{1, 2, 3}, ekit.ComparatorRealNumber[int]), + wantSlice: []int{1, 2, 3}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.wantSlice, tc.skiplist.AsSlice()) + }) + } +} diff --git a/internal/slice/add.go b/internal/slice/add.go index c368c50c..2e7a436e 100644 --- a/internal/slice/add.go +++ b/internal/slice/add.go @@ -18,7 +18,7 @@ import "github.com/ecodeclub/ekit/internal/errs" func Add[T any](src []T, element T, index int) ([]T, error) { length := len(src) - if index < 0 || index >= length { + if index < 0 || index > length { return nil, errs.NewErrIndexOutOfRange(length, index) } diff --git a/internal/slice/add_test.go b/internal/slice/add_test.go index 7236aece..8bfd0373 100644 --- a/internal/slice/add_test.go +++ b/internal/slice/add_test.go @@ -63,6 +63,20 @@ func TestAdd(t *testing.T) { index: 5, wantSlice: []int{123, 100, 101, 102, 102, 233, 102}, }, + { + name: "append on last", + slice: []int{123, 100, 101, 102, 102, 102}, + addVal: 233, + index: 6, + wantSlice: []int{123, 100, 101, 102, 102, 102, 233}, + }, + { + name: "index out of range", + slice: []int{123, 100, 101, 102, 102, 102}, + addVal: 233, + index: 7, + wantErr: errs.NewErrIndexOutOfRange(6, 7), + }, } for _, tc := range testCases { diff --git a/iox/json_reader.go b/iox/json_reader.go new file mode 100644 index 00000000..48b4322b --- /dev/null +++ b/iox/json_reader.go @@ -0,0 +1,50 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iox + +import ( + "bytes" + "encoding/json" +) + +type JSONReader struct { + val any + bf *bytes.Reader +} + +func (j *JSONReader) Read(p []byte) (n int, err error) { + if j.bf == nil { + var data []byte + data, err = json.Marshal(j.val) + if err == nil { + j.bf = bytes.NewReader(data) + } + } + if err != nil { + return + } + return j.bf.Read(p) +} + +// NewJSONReader 用于解决将一个结构体序列化为 JSON 之后,再封装为 io.Reader 的场景。 +// 该实现没有做任何输入检查。 +// 也就是你需要自己确保 val 是一个可以被 json 正确处理的东西。 +// 非线程安全。 +// 如果你传入的是 nil,那么读到的结果应该是 null。务必小心。 +func NewJSONReader(val any) *JSONReader { + return &JSONReader{ + val: val, + } +} diff --git a/iox/json_reader_test.go b/iox/json_reader_test.go new file mode 100644 index 00000000..111390da --- /dev/null +++ b/iox/json_reader_test.go @@ -0,0 +1,61 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iox + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJSONReader(t *testing.T) { + testCases := []struct { + name string + input []byte + val any + + wantRes []byte + wantN int + wantErr error + }{ + { + name: "正常读取", + input: make([]byte, 10), + wantN: 10, + val: User{Name: "Tom"}, + wantRes: []byte(`{"name":"T`), + }, + { + name: "输入 nil", + input: make([]byte, 7), + wantN: 4, + wantRes: append([]byte(`null`), 0, 0, 0), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + reader := NewJSONReader(tc.val) + n, err := reader.Read(tc.input) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantN, n) + assert.Equal(t, tc.wantRes, tc.input) + }) + } +} + +type User struct { + Name string `json:"name"` +} diff --git a/list/array_list.go b/list/array_list.go index d6dbe90f..6eaa9321 100644 --- a/list/array_list.go +++ b/list/array_list.go @@ -56,14 +56,9 @@ func (a *ArrayList[T]) Append(ts ...T) error { // Add 在ArrayList下标为index的位置插入一个元素 // 当index等于ArrayList长度等同于append -func (a *ArrayList[T]) Add(index int, t T) error { - if index < 0 || index > len(a.vals) { - return errs.NewErrIndexOutOfRange(len(a.vals), index) - } - a.vals = append(a.vals, t) - copy(a.vals[index+1:], a.vals[index:]) - a.vals[index] = t - return nil +func (a *ArrayList[T]) Add(index int, t T) (err error) { + a.vals, err = slice.Add(a.vals, t, index) + return } // Set 设置ArrayList里index位置的值为t diff --git a/list/skip_list.go b/list/skip_list.go new file mode 100644 index 00000000..128c6c33 --- /dev/null +++ b/list/skip_list.go @@ -0,0 +1,54 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package list + +import ( + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/list" +) + +func NewSkipList[T any](compare ekit.Comparator[T]) *SkipList[T] { + pq := &SkipList[T]{} + pq.skiplist = list.NewSkipList[T](compare) + return pq +} + +type SkipList[T any] struct { + skiplist *list.SkipList[T] +} + +func (sl *SkipList[T]) Search(target T) bool { + return sl.skiplist.Search(target) +} + +func (sl *SkipList[T]) AsSlice() []T { + return sl.skiplist.AsSlice() +} + +func (sl *SkipList[T]) Len() int { + return sl.skiplist.Len() +} + +func (sl *SkipList[T]) Cap() int { + return sl.Len() +} + +func (sl *SkipList[T]) Insert(Val T) { + sl.skiplist.Insert(Val) +} + +func (sl *SkipList[T]) DeleteElement(target T) bool { + return sl.skiplist.DeleteElement(target) +} diff --git a/list/skip_list_test.go b/list/skip_list_test.go new file mode 100644 index 00000000..db7564ad --- /dev/null +++ b/list/skip_list_test.go @@ -0,0 +1,172 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package list + +import ( + "testing" + + "github.com/ecodeclub/ekit" + "github.com/stretchr/testify/assert" +) + +func TestNewSkipList(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantSlice []int + }{ + { + name: "new skip list", + compare: ekit.ComparatorRealNumber[int], + wantSlice: []int{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList(tc.compare) + assert.Equal(t, tc.wantSlice, sl.AsSlice()) + }) + } +} + +func TestSkipList_AsSlice(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantSlice []int + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + wantSlice: []int{}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + assert.Equal(t, tc.wantSlice, sl.AsSlice()) + }) + } +} + +func TestSkipList_Cap(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantSize int + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + wantSize: 0, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + assert.Equal(t, tc.wantSize, sl.Cap()) + }) + } +} + +func TestSkipList_DeleteElement(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + value int + wantBool bool + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + value: 1, + wantBool: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + ok := sl.DeleteElement(tc.value) + assert.Equal(t, tc.wantBool, ok) + }) + } +} + +func TestSkipList_Insert(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + key int + wantSlice []int + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + key: 1, + wantSlice: []int{1}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + sl.Insert(tc.key) + assert.Equal(t, tc.wantSlice, sl.AsSlice()) + }) + } +} + +func TestSkipList_Len(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantSize int + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + wantSize: 0, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + assert.Equal(t, tc.wantSize, sl.Len()) + }) + } +} + +func TestSkipList_Search(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + value int + wantBool bool + }{ + { + name: "no err is ok", + compare: ekit.ComparatorRealNumber[int], + value: 1, + wantBool: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sl := NewSkipList[int](tc.compare) + ok := sl.Search(tc.value) + assert.Equal(t, tc.wantBool, ok) + }) + } +} diff --git a/list/types.go b/list/types.go index 5a2182ce..bf5b53b2 100644 --- a/list/types.go +++ b/list/types.go @@ -23,7 +23,9 @@ type List[T any] interface { // Append 在末尾追加元素 Append(ts ...T) error // Add 在特定下标处增加一个新元素 - // 如果下标超出范围,应该返回错误 + // 如果下标不在[0, Len()]范围之内 + // 应该返回错误 + // 如果index == Len()则表示往List末端增加一个值 Add(index int, t T) error // Set 重置 index 位置的值 // 如果下标超出范围,应该返回错误 diff --git a/mapx/builtin_map.go b/mapx/builtin_map.go index a5621d1e..c8484581 100644 --- a/mapx/builtin_map.go +++ b/mapx/builtin_map.go @@ -50,3 +50,7 @@ func newBuiltinMap[K comparable, V any](capacity int) *builtinMap[K, V] { data: make(map[K]V, capacity), } } + +func (b *builtinMap[K, V]) Len() int64 { + return int64(len(b.data)) +} diff --git a/mapx/builtin_map_test.go b/mapx/builtin_map_test.go index 63572e66..1cbfe26c 100644 --- a/mapx/builtin_map_test.go +++ b/mapx/builtin_map_test.go @@ -210,6 +210,42 @@ func TestBuiltinMap_Values(t *testing.T) { } } +func TestBuiltinMap_Len(t *testing.T) { + testCases := []struct { + name string + data map[string]string + + wantLen int64 + }{ + { + name: "got len", + data: map[string]string{ + "key1": "val1", + "key2": "val2", + "key3": "val3", + "key4": "val4", + }, + wantLen: 4, + }, + { + name: "empty map", + data: map[string]string{}, + wantLen: 0, + }, + { + name: "nil map", + wantLen: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + m := builtinMapOf[string, string](tc.data) + assert.Equal(t, tc.wantLen, m.Len()) + }) + } +} + func builtinMapOf[K comparable, V any](data map[K]V) *builtinMap[K, V] { return &builtinMap[K, V]{data: data} } diff --git a/mapx/hashmap.go b/mapx/hashmap.go index f2a7a6a5..01ea17dc 100644 --- a/mapx/hashmap.go +++ b/mapx/hashmap.go @@ -158,3 +158,7 @@ func (n *node[T, ValType]) formatting() { n.value = val n.next = nil } + +func (m *HashMap[T, ValType]) Len() int64 { + return int64(len(m.hashmap)) +} diff --git a/mapx/hashmap_test.go b/mapx/hashmap_test.go index 51161644..ef4aabf3 100644 --- a/mapx/hashmap_test.go +++ b/mapx/hashmap_test.go @@ -469,6 +469,77 @@ func TestHashMap_Keys_Values(t *testing.T) { } } +func TestHashMap_Len(t *testing.T) { + testCases := []struct { + name string + genHashMap func() *HashMap[testData, int] + wantLen int64 + }{ + { + name: "empty", + genHashMap: func() *HashMap[testData, int] { + return NewHashMap[testData, int](10) + }, + wantLen: 0, + }, + { + name: "single key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + err := testHashMap.Put(newTestData(1), 1) + require.NoError(t, err) + return testHashMap + }, + wantLen: 1, + }, + { + name: "multiple keys", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + for _, val := range []int{1, 2} { + err := testHashMap.Put(newTestData(val), val) + require.NoError(t, err) + } + return testHashMap + }, + wantLen: 2, + }, + { + name: "same key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + err := testHashMap.Put(newTestData(1), 1) + require.NoError(t, err) + // 验证id相同,覆盖的场景 + err = testHashMap.Put(newTestData(1), 11) + require.NoError(t, err) + return testHashMap + }, + wantLen: 1, + }, + { + name: "multi with same key", + genHashMap: func() *HashMap[testData, int] { + testHashMap := NewHashMap[testData, int](10) + for _, val := range []int{1, 2} { + // val为1、2 + err := testHashMap.Put(newTestData(val), val*10) + require.NoError(t, err) + } + err := testHashMap.Put(newTestData(1), 11) + require.NoError(t, err) + return testHashMap + }, + wantLen: 2, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.wantLen, tc.genHashMap().Len()) + }) + } +} + type testData struct { id int } diff --git a/mapx/linkedmap.go b/mapx/linkedmap.go index 127f32af..374dbf42 100644 --- a/mapx/linkedmap.go +++ b/mapx/linkedmap.go @@ -110,3 +110,7 @@ func (l *LinkedMap[K, V]) Values() []V { } return values } + +func (l *LinkedMap[K, V]) Len() int64 { + return int64(l.length) +} diff --git a/mapx/linkedmap_test.go b/mapx/linkedmap_test.go index a0060f3a..36e96a8b 100644 --- a/mapx/linkedmap_test.go +++ b/mapx/linkedmap_test.go @@ -467,3 +467,39 @@ func TestLinkedMap_PutAndDelete(t *testing.T) { }) } } + +func TestLinkedMap_Len(t *testing.T) { + testCases := []struct { + name string + linkedMap func(t *testing.T) *LinkedMap[int, int] + + wantLen int64 + }{ + { + name: "empty linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, _ := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + return linkedTreeMap + }, + + wantLen: 0, + }, + { + name: "not empty linked map", + linkedMap: func(t *testing.T) *LinkedMap[int, int] { + linkedTreeMap, _ := NewLinkedTreeMap[int, int](ekit.ComparatorRealNumber[int]) + assert.NoError(t, linkedTreeMap.Put(1, 1)) + assert.NoError(t, linkedTreeMap.Put(2, 2)) + assert.NoError(t, linkedTreeMap.Put(3, 3)) + return linkedTreeMap + }, + + wantLen: 3, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantLen, tt.linkedMap(t).Len()) + }) + } +} diff --git a/mapx/map.go b/mapx/map.go index 0d3f8678..78e838ff 100644 --- a/mapx/map.go +++ b/mapx/map.go @@ -14,6 +14,8 @@ package mapx +import "fmt" + // Keys 返回 map 里面的所有的 key。 // 需要注意:这些 key 的顺序是随机。 func Keys[K comparable, V any](m map[K]V) []K { @@ -45,3 +47,28 @@ func KeysValues[K comparable, V any](m map[K]V) ([]K, []V) { } return keys, values } + +// ToMap 将会返回一个map[K]V +// 请保证传入的 keys 与 values 长度相同,长度均为n +// 长度不相同或者 keys 或者 values 为nil则会抛出异常 +// 返回的 m map[K]V 保证对于所有的 0 <= i < n +// m[keys[i]] = values[i] +// +// 注意: +// 如果传入的数组中存在 0 <= i < j < n使得 keys[i] == keys[j] +// 则在返回的 m 中 m[keys[i]] = values[j] +// 如果keys和values的长度为0,则会返回一个空map +func ToMap[K comparable, V any](keys []K, values []V) (m map[K]V, err error) { + if keys == nil || values == nil { + return nil, fmt.Errorf("keys与values均不可为nil") + } + n := len(keys) + if n != len(values) { + return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values)) + } + m = make(map[K]V, n) + for i := 0; i < n; i++ { + m[keys[i]] = values[i] + } + return +} diff --git a/mapx/map_test.go b/mapx/map_test.go index 569ac2a2..0d847013 100644 --- a/mapx/map_test.go +++ b/mapx/map_test.go @@ -15,6 +15,7 @@ package mapx import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -147,3 +148,63 @@ func TestKeysValues(t *testing.T) { }) } } + +func TestToMap(t *testing.T) { + type caseType struct { + keys []int + values []string + + result map[int]string + err error + } + for _, c := range []caseType{ + { + keys: []int{1, 2, 3}, + values: []string{"1", "2", "3"}, + result: map[int]string{ + 1: "1", + 2: "2", + 3: "3", + }, + err: nil, + }, + { + keys: []int{1, 2, 3}, + values: []string{"1", "2"}, + result: nil, + err: fmt.Errorf("keys与values的长度不同, len(keys)=3, len(values)=2"), + }, + { + keys: []int{1, 2, 3}, + values: nil, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: []string{"1", "2"}, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: nil, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1, 2, 3, 1, 1}, + values: []string{"1", "2", "3", "10", "100"}, + result: map[int]string{ + 1: "100", + 2: "2", + 3: "3", + }, + err: nil, + }, + } { + result, err := ToMap(c.keys, c.values) + assert.Equal(t, c.err, err) + assert.Equal(t, c.result, result) + } +} diff --git a/mapx/multi_map.go b/mapx/multi_map.go index f6497164..32668e05 100644 --- a/mapx/multi_map.go +++ b/mapx/multi_map.go @@ -93,3 +93,8 @@ func (m *MultiMap[K, V]) Values() [][]V { } return copyValues } + +// Len 返回 MultiMap 键值对的数量 +func (m *MultiMap[K, V]) Len() int64 { + return m.m.Len() +} diff --git a/mapx/multi_map_test.go b/mapx/multi_map_test.go index 0875763a..a81697ea 100644 --- a/mapx/multi_map_test.go +++ b/mapx/multi_map_test.go @@ -594,3 +594,81 @@ func TestMultiMap_PutMany(t *testing.T) { }) } } + +func TestMultiMap_Len(t *testing.T) { + testCases := []struct { + name string + multiTreeMap *MultiMap[int, int] + multiHashMap *MultiMap[testData, int] + wantLen int64 + }{ + { + name: "empty", + multiTreeMap: getMultiTreeMap(), + multiHashMap: getMultiHashMap(), + + wantLen: 0, + }, + { + name: "single", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + return multiHashMap + }(), + + wantLen: 1, + }, + { + name: "multiple", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(2, 2) + _ = multiTreeMap.Put(3, 3) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 2}, 2) + _ = multiHashMap.Put(testData{id: 3}, 3) + return multiHashMap + }(), + + wantLen: 3, + }, + { + name: "multiple with same key", + multiTreeMap: func() *MultiMap[int, int] { + multiTreeMap := getMultiTreeMap() + _ = multiTreeMap.Put(1, 1) + _ = multiTreeMap.Put(1, 2) + _ = multiTreeMap.Put(1, 3) + return multiTreeMap + }(), + multiHashMap: func() *MultiMap[testData, int] { + multiHashMap := getMultiHashMap() + _ = multiHashMap.Put(testData{id: 1}, 1) + _ = multiHashMap.Put(testData{id: 1}, 2) + _ = multiHashMap.Put(testData{id: 1}, 3) + return multiHashMap + }(), + wantLen: 1, + }, + } + for _, tt := range testCases { + t.Run("MultiTreeMap", func(t *testing.T) { + assert.Equal(t, tt.wantLen, tt.multiTreeMap.Len()) + }) + + t.Run("MultiHashMap", func(t *testing.T) { + assert.Equal(t, tt.wantLen, tt.multiHashMap.Len()) + }) + } +} diff --git a/mapx/treemap.go b/mapx/treemap.go index 719d696d..40f611fa 100644 --- a/mapx/treemap.go +++ b/mapx/treemap.go @@ -96,4 +96,9 @@ func (treeMap *TreeMap[T, V]) Values() []V { return vals } +// Len 返回了键值对的数量 +func (treeMap *TreeMap[T, V]) Len() int64 { + return int64(treeMap.tree.Size()) +} + var _ mapi[any, any] = (*TreeMap[any, any])(nil) diff --git a/mapx/treemap_test.go b/mapx/treemap_test.go index 378000fd..c224631a 100644 --- a/mapx/treemap_test.go +++ b/mapx/treemap_test.go @@ -409,6 +409,42 @@ func TestTreeMap_Delete(t *testing.T) { } } +func TestTreeMap_Len(t *testing.T) { + var tests = []struct { + name string + m map[int]int + len int64 + }{ + { + name: "empty-TreeMap", + m: map[int]int{}, + len: 0, + }, + { + name: "find", + m: map[int]int{ + 1: 1, + 2: 2, + 0: 0, + 3: 3, + 5: 5, + 4: 4, + }, + len: 6, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + treeMap, _ := NewTreeMap[int, int](compare()) + for k, v := range tt.m { + err := treeMap.Put(k, v) + require.NoError(t, err) + } + assert.Equal(t, tt.len, treeMap.Len()) + }) + } +} + func compare() ekit.Comparator[int] { return ekit.ComparatorRealNumber[int] } diff --git a/mapx/types.go b/mapx/types.go index 8b43c7aa..88806fb0 100644 --- a/mapx/types.go +++ b/mapx/types.go @@ -29,4 +29,6 @@ type mapi[K any, V any] interface { // 注意,当你调用多次拿到的结果不一定相等 // 取决于具体实现 Values() []V + // 返回键值对数量 + Len() int64 } diff --git a/net/httpx/httptestx/recorder.go b/net/httpx/httptestx/recorder.go new file mode 100644 index 00000000..1eee3562 --- /dev/null +++ b/net/httpx/httptestx/recorder.go @@ -0,0 +1,44 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptestx + +import ( + "encoding/json" + "net/http/httptest" +) + +type JSONResponseRecorder[T any] struct { + *httptest.ResponseRecorder +} + +func NewJSONResponseRecorder[T any]() *JSONResponseRecorder[T] { + return &JSONResponseRecorder[T]{ + ResponseRecorder: httptest.NewRecorder(), + } +} + +func (r JSONResponseRecorder[T]) Scan() (T, error) { + var t T + err := json.NewDecoder(r.Body).Decode(&t) + return t, err +} + +func (r JSONResponseRecorder[T]) MustScan() T { + t, err := r.Scan() + if err != nil { + panic(err) + } + return t +} diff --git a/net/httpx/httptestx/recorder_test.go b/net/httpx/httptestx/recorder_test.go new file mode 100644 index 00000000..1e0a00f4 --- /dev/null +++ b/net/httpx/httptestx/recorder_test.go @@ -0,0 +1,41 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httptestx + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONResponseRecorder_MustScan(t *testing.T) { + // 成功案例 + recorder := NewJSONResponseRecorder[User]() + _, err := recorder.WriteString(`{"name": "Tom"}`) + require.NoError(t, err) + u := recorder.MustScan() + assert.Equal(t, User{Name: "Tom"}, u) + + // panic 案例 + recorder = NewJSONResponseRecorder[User]() + assert.Panics(t, func() { + recorder.MustScan() + }) +} + +type User struct { + Name string `json:"name"` +} diff --git a/net/httpx/log_round_trip.go b/net/httpx/log_round_trip.go new file mode 100644 index 00000000..2dca4be6 --- /dev/null +++ b/net/httpx/log_round_trip.go @@ -0,0 +1,67 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "bytes" + "io" + "net/http" +) + +type LogRoundTrip struct { + delegate http.RoundTripper + // l 绝对不会为 nil + log func(l Log, err error) +} + +func NewLogRoundTrip(rp http.RoundTripper, log func(l Log, err error)) *LogRoundTrip { + return &LogRoundTrip{ + delegate: rp, + log: log, + } +} + +func (l *LogRoundTrip) RoundTrip(request *http.Request) (resp *http.Response, err error) { + log := Log{ + URL: request.URL.String(), + } + defer func() { + if resp != nil { + log.RespStatus = resp.Status + if resp.Body != nil { + // 出现 error 了这里也不知道怎么处理,暂时忽略 + body, _ := io.ReadAll(resp.Body) + resp.Body = io.NopCloser(bytes.NewReader(body)) + log.RespBody = string(body) + } + } + l.log(log, err) + }() + if request.Body != nil { + // 出现 error 了这里也不知道怎么处理,暂时忽略 + body, _ := io.ReadAll(request.Body) + request.Body = io.NopCloser(bytes.NewReader(body)) + log.ReqBody = string(body) + } + resp, err = l.delegate.RoundTrip(request) + return +} + +type Log struct { + URL string + ReqBody string + RespBody string + RespStatus string +} diff --git a/net/httpx/log_round_trip_test.go b/net/httpx/log_round_trip_test.go new file mode 100644 index 00000000..80d1b72d --- /dev/null +++ b/net/httpx/log_round_trip_test.go @@ -0,0 +1,57 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLogRoundTrip(t *testing.T) { + client := http.DefaultClient + var acceptLog Log + var acceptError error + client.Transport = NewLogRoundTrip(&doNothingRoundTrip{}, func(l Log, err error) { + acceptLog = l + acceptError = err + }) + NewRequest(context.Background(), + http.MethodGet, "http://localhost/test"). + JSONBody(User{Name: "Tom"}). + Client(client). + Do() + assert.Equal(t, nil, acceptError) + assert.Equal(t, Log{ + URL: "http://localhost/test", + ReqBody: `{"Name":"Tom"}`, + RespBody: "resp body", + RespStatus: "200 OK", + }, acceptLog) +} + +type doNothingRoundTrip struct { +} + +func (d *doNothingRoundTrip) RoundTrip(request *http.Request) (*http.Response, error) { + return &http.Response{ + Status: "200 OK", + Body: io.NopCloser(bytes.NewBuffer([]byte("resp body"))), + }, nil +} diff --git a/net/httpx/request.go b/net/httpx/request.go new file mode 100644 index 00000000..e47e4d94 --- /dev/null +++ b/net/httpx/request.go @@ -0,0 +1,77 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "context" + "io" + "net/http" + + "github.com/ecodeclub/ekit/iox" +) + +type Request struct { + req *http.Request + err error + client *http.Client +} + +func NewRequest(ctx context.Context, method, url string) *Request { + req, err := http.NewRequestWithContext(ctx, method, url, nil) + return &Request{ + req: req, + err: err, + client: http.DefaultClient, + } +} + +// JSONBody 使用 JSON body +func (req *Request) JSONBody(val any) *Request { + req.req.Body = io.NopCloser(iox.NewJSONReader(val)) + req.req.Header.Set("Content-Type", "application/json") + return req +} + +func (req *Request) Client(cli *http.Client) *Request { + req.client = cli + return req +} + +func (req *Request) AddHeader(key string, value string) *Request { + req.req.Header.Add(key, value) + return req +} + +// AddParam 添加查询参数 +// 这个方法性能不好,但是好用 +func (req *Request) AddParam(key string, value string) *Request { + q := req.req.URL.Query() + q.Add(key, value) + req.req.URL.RawQuery = q.Encode() + return req +} + +func (req *Request) Do() *Response { + if req.err != nil { + return &Response{ + err: req.err, + } + } + resp, err := req.client.Do(req.req) + return &Response{ + Response: resp, + err: err, + } +} diff --git a/net/httpx/request_test.go b/net/httpx/request_test.go new file mode 100644 index 00000000..1ed05e73 --- /dev/null +++ b/net/httpx/request_test.go @@ -0,0 +1,118 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "context" + "errors" + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequest_Client(t *testing.T) { + req := NewRequest(context.Background(), http.MethodPost, "/abc") + assert.Equal(t, http.DefaultClient, req.client) + cli := &http.Client{} + req = req.Client(&http.Client{}) + assert.Equal(t, cli, req.client) +} + +func TestRequest_JSONBody(t *testing.T) { + req := NewRequest(context.Background(), http.MethodPost, "/abc") + assert.Nil(t, req.req.Body) + req = req.JSONBody(User{}) + assert.NotNil(t, req.req.Body) + assert.Equal(t, "application/json", req.req.Header.Get("Content-Type")) +} + +func TestRequest_Do(t *testing.T) { + l, err := net.Listen("unix", "/tmp/test.sock") + require.NoError(t, err) + server := http.Server{} + go func() { + http.HandleFunc("/hello", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("OK")) + }) + _ = server.Serve(l) + }() + defer func() { + _ = l.Close() + }() + testCases := []struct { + name string + req func() *Request + wantErr error + }{ + { + name: "构造请求的时候有 error", + req: func() *Request { + return &Request{ + err: errors.New("mock error"), + } + }, + wantErr: errors.New("mock error"), + }, + { + name: "成功", + req: func() *Request { + req := NewRequest(context.Background(), http.MethodGet, "http://localhost:8081/hello") + return req.Client(&http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, + network, addr string) (net.Conn, error) { + return net.Dial("unix", "/tmp/test.sock") + }, + }, + }) + }, + }, + } + + // 确保前面的 http 端口启动成功 + time.Sleep(time.Second) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := tc.req() + resp := req.Do() + assert.Equal(t, tc.wantErr, resp.err) + }) + } +} + +func TestRequest_AddParam(t *testing.T) { + req := NewRequest(context.Background(), + http.MethodGet, "http://localhost"). + AddParam("key1", "value1"). + AddParam("key2", "value2") + assert.Equal(t, "http://localhost?key1=value1&key2=value2", req.req.URL.String()) +} + +func TestRequestAddHeader(t *testing.T) { + req := NewRequest(context.Background(), + http.MethodGet, "http://localhost"). + AddHeader("head1", "val1").AddHeader("head1", "val2") + vals := req.req.Header.Values("head1") + assert.Equal(t, []string{"val1", "val2"}, vals) +} + +type User struct { + Name string +} diff --git a/net/httpx/response.go b/net/httpx/response.go new file mode 100644 index 00000000..be6c4374 --- /dev/null +++ b/net/httpx/response.go @@ -0,0 +1,34 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "encoding/json" + "net/http" +) + +type Response struct { + *http.Response + err error +} + +// JSONScan 将 Body 按照 JSON 反序列化为结构体 +func (r *Response) JSONScan(val any) error { + if r.err != nil { + return r.err + } + err := json.NewDecoder(r.Body).Decode(val) + return err +} diff --git a/net/httpx/response_test.go b/net/httpx/response_test.go new file mode 100644 index 00000000..a1e2fa24 --- /dev/null +++ b/net/httpx/response_test.go @@ -0,0 +1,54 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpx + +import ( + "io" + "net/http" + "testing" + + "github.com/ecodeclub/ekit/iox" + "github.com/stretchr/testify/assert" +) + +func TestResponse_JSONScan(t *testing.T) { + testCases := []struct { + name string + resp *Response + wantVal User + wantErr error + }{ + { + name: "scan成功", + resp: &Response{ + Response: &http.Response{ + Body: io.NopCloser(iox.NewJSONReader(User{Name: "Tom"})), + }, + }, + wantVal: User{ + Name: "Tom", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var u User + err := tc.resp.JSONScan(&u) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantVal, u) + }) + } +} diff --git a/queue/errs.go b/queue/errs.go new file mode 100644 index 00000000..0663deda --- /dev/null +++ b/queue/errs.go @@ -0,0 +1,20 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import "github.com/ecodeclub/ekit/internal/queue" + +// ErrOutOfCapacity 超过容量 +var ErrOutOfCapacity = queue.ErrOutOfCapacity diff --git a/queue/priority_queue.go b/queue/priority_queue.go new file mode 100644 index 00000000..41c41578 --- /dev/null +++ b/queue/priority_queue.go @@ -0,0 +1,46 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/queue" +) + +type PriorityQueue[T any] struct { + priorityQueue *queue.PriorityQueue[T] +} + +func NewPriorityQueue[T any](capacity int, compare ekit.Comparator[T]) *PriorityQueue[T] { + pq := &PriorityQueue[T]{} + pq.priorityQueue = queue.NewPriorityQueue[T](capacity, compare) + return pq +} + +func (pq *PriorityQueue[T]) Len() int { + return pq.priorityQueue.Len() +} + +func (pq *PriorityQueue[T]) Peek() (T, error) { + return pq.priorityQueue.Peek() +} + +func (pq *PriorityQueue[T]) Enqueue(t T) error { + return pq.priorityQueue.Enqueue(t) +} + +func (pq *PriorityQueue[T]) Dequeue() (T, error) { + return pq.priorityQueue.Dequeue() +} diff --git a/queue/priority_queue_test.go b/queue/priority_queue_test.go new file mode 100644 index 00000000..ed7897ae --- /dev/null +++ b/queue/priority_queue_test.go @@ -0,0 +1,147 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "testing" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/queue" + "github.com/stretchr/testify/assert" +) + +func compare() ekit.Comparator[int] { + return ekit.ComparatorRealNumber[int] +} + +func TestNewPriorityQueue(t *testing.T) { + testCases := []struct { + name string + initSize int + compare ekit.Comparator[int] + wantErr error + }{ + { + name: "compare is nil", + initSize: 8, + compare: nil, + }, + { + name: "compare is ok", + initSize: 8, + compare: compare(), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _ = NewPriorityQueue[int](tc.initSize, tc.compare) + }) + } +} + +func TestPriorityQueue_Len(t *testing.T) { + testCases := []struct { + name string + initSize int + compare ekit.Comparator[int] + wantLen int + }{ + { + name: "no err is ok", + initSize: 8, + compare: compare(), + wantLen: 0, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pq := NewPriorityQueue[int](tc.initSize, tc.compare) + assert.Equal(t, tc.wantLen, pq.Len()) + }) + } +} + +func TestPriorityQueue_Peek(t *testing.T) { + testCases := []struct { + name string + initSize int + compare ekit.Comparator[int] + wantResult int + wantErr error + }{ + { + name: "no err is ok", + initSize: 8, + compare: compare(), + wantErr: queue.ErrEmptyQueue, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pq := NewPriorityQueue[int](tc.initSize, tc.compare) + result, err := pq.Peek() + assert.Equal(t, tc.wantResult, result) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestPriorityQueue_Enqueue(t *testing.T) { + testCases := []struct { + name string + initSize int + compare ekit.Comparator[int] + enqueueData int + wantErr error + }{ + { + name: "no err is ok", + initSize: 8, + compare: compare(), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pq := NewPriorityQueue[int](tc.initSize, tc.compare) + err := pq.Enqueue(tc.enqueueData) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestPriorityQueue_Dequeue(t *testing.T) { + testCases := []struct { + name string + initSize int + compare ekit.Comparator[int] + wantResult int + wantErr error + }{ + { + name: "no err is ok", + initSize: 8, + compare: compare(), + wantErr: queue.ErrEmptyQueue, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pq := NewPriorityQueue[int](tc.initSize, tc.compare) + result, err := pq.Dequeue() + assert.Equal(t, tc.wantResult, result) + assert.Equal(t, tc.wantErr, err) + }) + } +} diff --git a/randx/rand_code.go b/randx/rand_code.go index 776d0559..70849bf8 100644 --- a/randx/rand_code.go +++ b/randx/rand_code.go @@ -17,36 +17,95 @@ package randx import ( "errors" "math/rand" + + "github.com/ecodeclub/ekit/tuple/pair" ) -var ERRTYPENOTSUPPORTTED = errors.New("ekit:不支持的类型") +var ( + errTypeNotSupported = errors.New("ekit:不支持的类型") + errLengthLessThanZero = errors.New("ekit:长度必须大于等于0") +) -type TYPE int +type Type int const ( - TYPE_DEFAULT TYPE = 0 //默认类型 - TYPE_DIGIT TYPE = 1 //数字// - TYPE_LETTER TYPE = 2 //小写字母 - TYPE_CAPITAL TYPE = 3 //大写字母 - TYPE_MIXED TYPE = 4 //数字+字母混合 + // TypeDigit 数字 + TypeDigit Type = 1 + // TypeLowerCase 小写字母 + TypeLowerCase Type = 1 << 1 + // TypeUpperCase 大写字母 + TypeUpperCase Type = 1 << 2 + // TypeSpecial 特殊符号 + TypeSpecial Type = 1 << 3 + // TypeMixed 混合类型 + TypeMixed = (TypeDigit | TypeUpperCase | TypeLowerCase | TypeSpecial) + + // CharsetDigit 数字字符组 + CharsetDigit = "0123456789" + // CharsetLowerCase 小写字母字符组 + CharsetLowerCase = "abcdefghijklmnopqrstuvwxyz" + // CharsetUpperCase 大写字母字符组 + CharsetUpperCase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + // CharsetSpecial 特殊字符数组 + CharsetSpecial = " ~!@#$%^&*()_+-=[]{};'\\:\"|,./<>?" +) + +var ( + // 只限于randx包内部使用 + typeCharsetPairs = []pair.Pair[Type, string]{ + pair.NewPair(TypeDigit, CharsetDigit), + pair.NewPair(TypeLowerCase, CharsetLowerCase), + pair.NewPair(TypeUpperCase, CharsetUpperCase), + pair.NewPair(TypeSpecial, CharsetSpecial), + } ) -// RandCode 根据传入的长度和类型生成随机字符串,这个方法目前可以生成数字、字母、数字+字母的随机字符串 -func RandCode(length int, typ TYPE) (string, error) { - switch typ { - case TYPE_DEFAULT: - fallthrough - case TYPE_DIGIT: - return generate("0123456789", length, 4), nil - case TYPE_LETTER: - return generate("abcdefghijklmnopqrstuvwxyz", length, 5), nil - case TYPE_CAPITAL: - return generate("ABCDEFGHIJKLMNOPQRSTUVWXYZ", length, 5), nil - case TYPE_MIXED: - return generate("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", length, 7), nil - default: - return "", ERRTYPENOTSUPPORTTED +// RandCode 根据传入的长度和类型生成随机字符串 +// 请保证输入的 length >= 0,否则会返回 errLengthLessThanZero +// 请保证输入的 typ 的取值范围在 (0, type.MIXED] 内,否则会返回 errTypeNotSupported +func RandCode(length int, typ Type) (string, error) { + if length < 0 { + return "", errLengthLessThanZero + } + if length == 0 { + return "", nil + } + if typ > TypeMixed { + return "", errTypeNotSupported + } + charset := "" + for _, p := range typeCharsetPairs { + if (typ & p.Key) == p.Key { + charset += p.Value + } + } + return RandStrByCharset(length, charset) +} + +// RandStrByCharset 根据传入的长度和字符集生成随机字符串 +// 请保证输入的 length >= 0,否则会返回 errLengthLessThanZero +// 请保证输入的字符集不为空字符串,否则会返回 errTypeNotSupported +// 字符集内部字符可以无序或重复 +func RandStrByCharset(length int, charset string) (string, error) { + if length < 0 { + return "", errLengthLessThanZero + } + if length == 0 { + return "", nil + } + charsetSize := len(charset) + if charsetSize == 0 { + return "", errTypeNotSupported + } + return generate(charset, length, getFirstMask(charsetSize)), nil +} + +func getFirstMask(charsetSize int) int { + bits := 0 + for charsetSize > ((1 << bits) - 1) { + bits++ } + return bits } // generate 根据传入的随机源和长度生成随机字符串,一次随机,多次使用 diff --git a/randx/rand_code_test.go b/randx/rand_code_test.go index ee962eee..9db13b8c 100644 --- a/randx/rand_code_test.go +++ b/randx/rand_code_test.go @@ -12,81 +12,212 @@ // See the License for the specific language governing permissions and // limitations under the License. -package randx +package randx_test import ( "errors" "regexp" + "strings" "testing" + + "github.com/ecodeclub/ekit/randx" + "github.com/stretchr/testify/assert" +) + +var ( + errTypeNotSupported = errors.New("ekit:不支持的类型") + errLengthLessThanZero = errors.New("ekit:长度必须大于等于0") ) func TestRandCode(t *testing.T) { testCases := []struct { name string length int - typ TYPE + typ randx.Type wantMatch string wantErr error }{ { - name: "默认类型", - length: 8, - typ: TYPE_DEFAULT, + name: "数字验证码", + length: 100, + typ: randx.TypeDigit, wantMatch: "^[0-9]+$", wantErr: nil, }, { - name: "数字验证码", - length: 8, - typ: TYPE_DIGIT, - wantMatch: "^[0-9]+$", - wantErr: nil, - }, { name: "小写字母验证码", - length: 8, - typ: TYPE_LETTER, + length: 100, + typ: randx.TypeLowerCase, wantMatch: "^[a-z]+$", wantErr: nil, - }, { + }, + { + name: "数字+小写字母验证码", + length: 100, + typ: randx.TypeDigit | randx.TypeLowerCase, + wantMatch: "^[a-z0-9]+$", + wantErr: nil, + }, + { + name: "数字+大写字母验证码", + length: 100, + typ: randx.TypeDigit | randx.TypeUpperCase, + wantMatch: "^[A-Z0-9]+$", + wantErr: nil, + }, + { name: "大写字母验证码", - length: 8, - typ: TYPE_CAPITAL, + length: 100, + typ: randx.TypeUpperCase, wantMatch: "^[A-Z]+$", wantErr: nil, - }, { - name: "混合验证码", - length: 8, - typ: TYPE_MIXED, + }, + { + name: "大小写字母验证码", + length: 100, + typ: randx.TypeUpperCase | randx.TypeLowerCase, + wantMatch: "^[a-zA-Z]+$", + wantErr: nil, + }, + { + name: "数字+大小写字母验证码", + length: 100, + typ: randx.TypeDigit | randx.TypeUpperCase | randx.TypeLowerCase, wantMatch: "^[0-9a-zA-Z]+$", wantErr: nil, - }, { - name: "未定义类型", - length: 8, - typ: 9, + }, + { + name: "所有类型验证", + length: 100, + typ: randx.TypeMixed, + wantMatch: "^[\\S\\s]+$", + wantErr: nil, + }, + { + name: "特殊字符类型验证", + length: 100, + typ: randx.TypeSpecial, + wantMatch: "^[^0-9a-zA-Z]+$", + wantErr: nil, + }, + { + name: "未定义类型(超过范围)", + length: 100, + typ: randx.TypeMixed + 1, + wantMatch: "", + wantErr: errTypeNotSupported, + }, + { + name: "未定义类型(0)", + length: 100, + typ: 0, wantMatch: "", - wantErr: ERRTYPENOTSUPPORTTED, + wantErr: errTypeNotSupported, + }, + { + name: "长度小于0", + length: -1, + typ: 0, + wantMatch: "", + wantErr: errLengthLessThanZero, + }, + { + name: "长度等于0", + length: 0, + typ: randx.TypeMixed, + wantMatch: "", + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + code, err := randx.RandCode(tc.length, tc.typ) + if tc.wantErr != nil { + assert.Equal(t, tc.wantErr, err) + return + } + assert.Len(t, code, tc.length) + if tc.length > 0 { + matched, err := regexp.MatchString(tc.wantMatch, code) + assert.Nil(t, err) + assert.Truef(t, matched, "expected %s but got %s", tc.wantMatch, code) + } + + }) + } +} + +func TestRandStrByCharset(t *testing.T) { + matchFunc := func(str, charset string) bool { + for _, c := range str { + if !strings.Contains(charset, string(c)) { + return false + } + } + return true + } + testCases := []struct { + name string + length int + charset string + wantErr error + }{ + { + name: "长度小于0", + length: -1, + charset: "123", + wantErr: errLengthLessThanZero, + }, + { + name: "长度等于0", + length: 0, + charset: "123", + wantErr: nil, + }, + { + name: "随机字符串测试", + length: 100, + charset: "2rg248ry227t@@", + wantErr: nil, + }, + { + name: "随机字符串测试", + length: 100, + charset: "2rg248ry227t@&*($.!", + wantErr: nil, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - code, err := RandCode(tc.length, tc.typ) - if err != nil { - if !errors.Is(err, tc.wantErr) { - t.Errorf("unexpected error: %v", err) - } - } else { - //长度检验 - if len(code) != tc.length { - t.Errorf("expected length: %d but got length:%d ", tc.length, len(code)) - } - //模式检验 - matched, _ := regexp.MatchString(tc.wantMatch, code) - if !matched { - t.Errorf("expected %s but got %s", tc.wantMatch, code) - } + code, err := randx.RandStrByCharset(tc.length, tc.charset) + if tc.wantErr != nil { + assert.Equal(t, tc.wantErr, err) + return } + + assert.Len(t, code, tc.length) + if tc.length > 0 { + assert.True(t, matchFunc(code, tc.charset)) + } + }) } +} +// goos: linux +// goarch: amd64 +// pkg: github.com/ecodeclub/ekit/randx +// cpu: 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz +// BenchmarkRandCode_MIXED/length=1000000-8 1000000000 0.004584 ns/op 0 B/op 0 allocs/op +func BenchmarkRandCode_MIXED(b *testing.B) { + b.Run("length=1000000", func(b *testing.B) { + n := 1000000 + b.StartTimer() + res, err := randx.RandCode(n, randx.TypeMixed) + b.StopTimer() + assert.Nil(b, err) + assert.Len(b, res, n) + }) } diff --git a/slice/add.go b/slice/add.go index a553a9f4..ed61cba2 100644 --- a/slice/add.go +++ b/slice/add.go @@ -17,7 +17,8 @@ package slice import "github.com/ecodeclub/ekit/internal/slice" // Add 在index处添加元素 -// index 范围应为[0, len(src)) +// index 范围应为[0, len(src)] +// 如果index == len(src) 则表示往末尾添加元素 func Add[Src any](src []Src, element Src, index int) ([]Src, error) { res, err := slice.Add[Src](src, element, index) return res, err diff --git a/slice/map.go b/slice/map.go index 11ce496f..9febe56b 100644 --- a/slice/map.go +++ b/slice/map.go @@ -36,6 +36,57 @@ func Map[Src any, Dst any](src []Src, m func(idx int, src Src) Dst) []Dst { return dst } +// 将[]Ele映射到map[Key]Ele +// 从Ele中提取Key的函数fn由使用者提供 +// +// 注意: +// 如果出现 i < j +// 设: +// +// key_i := fn(elements[i]) +// key_j := fn(elements[j]) +// +// 满足key_i == key_j 的情况,则在返回结果的resultMap中 +// resultMap[key_i] = val_j +// +// 即使传入的字符串为nil,也保证返回的map是一个空map而不是nil +func ToMap[Ele any, Key comparable]( + elements []Ele, + fn func(element Ele) Key, +) map[Key]Ele { + return ToMapV( + elements, + func(element Ele) (Key, Ele) { + return fn(element), element + }) +} + +// 将[]Ele映射到map[Key]Val +// 从Ele中提取Key和Val的函数fn由使用者提供 +// +// 注意: +// 如果出现 i < j +// 设: +// +// key_i, val_i := fn(elements[i]) +// key_j, val_j := fn(elements[j]) +// +// 满足key_i == key_j 的情况,则在返回结果的resultMap中 +// resultMap[key_i] = val_j +// +// 即使传入的字符串为nil,也保证返回的map是一个空map而不是nil +func ToMapV[Ele any, Key comparable, Val any]( + elements []Ele, + fn func(element Ele) (Key, Val), +) (resultMap map[Key]Val) { + resultMap = make(map[Key]Val, len(elements)) + for _, element := range elements { + k, v := fn(element) + resultMap[k] = v + } + return +} + // 构造map func toMap[T comparable](src []T) map[T]struct{} { var dataMap = make(map[T]struct{}, len(src)) diff --git a/slice/map_test.go b/slice/map_test.go index a9473746..1af40581 100644 --- a/slice/map_test.go +++ b/slice/map_test.go @@ -101,3 +101,248 @@ func ExampleFilterMap() { fmt.Println(dst) // Output: [1 3] } + +func TestToMapV(t *testing.T) { + t.Run("integer-string to map[int]int", func(t *testing.T) { + elements := []string{"1", "2", "3", "4", "5"} + resMap := ToMapV(elements, func(str string) (int, int) { + num, _ := strconv.Atoi(str) + return num, num + }) + epectedMap := map[int]int{ + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + } + assert.Equal(t, epectedMap, resMap) + }) + t.Run("struct to map[string]struct", func(t *testing.T) { + type eleType struct { + A string + B string + C int + } + elements := []eleType{ + { + A: "a", + B: "b", + C: 1, + }, + { + A: "c", + B: "d", + C: 2, + }, + } + resMap := ToMapV(elements, func(ele eleType) (string, eleType) { + return ele.A, ele + }) + epectedMap := map[string]eleType{ + "a": { + A: "a", + B: "b", + C: 1, + }, + "c": { + A: "c", + B: "d", + C: 2, + }, + } + assert.Equal(t, epectedMap, resMap) + }) + + t.Run("struct to map[string]struct, 重复的key", func(t *testing.T) { + type eleType struct { + A string + B string + C int + } + elements := []eleType{ + { + A: "a", + B: "b", + C: 1, + }, + { + A: "c", + B: "d", + C: 2, + }, + { + A: "a", + B: "d", + C: 3, + }, + } + resMap := ToMapV(elements, func(ele eleType) (string, eleType) { + return ele.A, ele + }) + epectedMap := map[string]eleType{ + "a": { + A: "a", + B: "d", + C: 3, + }, + "c": { + A: "c", + B: "d", + C: 2, + }, + } + assert.Equal(t, epectedMap, resMap) + }) + + t.Run("传入nil slice,返回空map", func(t *testing.T) { + var elements []string = nil + resMap := ToMapV(elements, func(str string) (int, int) { + num, _ := strconv.Atoi(str) + return num, num + }) + epectedMap := make(map[int]int) + assert.Equal(t, epectedMap, resMap) + }) +} + +func TestToMap(t *testing.T) { + t.Run("integer-string to map[int]string", func(t *testing.T) { + elements := []string{"1", "2", "3", "4", "5"} + resMap := ToMap(elements, func(str string) int { + num, _ := strconv.Atoi(str) + return num + }) + epectedMap := map[int]string{ + 1: "1", + 2: "2", + 3: "3", + 4: "4", + 5: "5", + } + assert.Equal(t, epectedMap, resMap) + }) + t.Run("struct to map[string]struct", func(t *testing.T) { + type eleType struct { + A string + B string + C int + } + elements := []eleType{ + { + A: "a", + B: "b", + C: 1, + }, + { + A: "c", + B: "d", + C: 2, + }, + } + resMap := ToMap(elements, func(ele eleType) string { + return ele.A + }) + epectedMap := map[string]eleType{ + "a": { + A: "a", + B: "b", + C: 1, + }, + "c": { + A: "c", + B: "d", + C: 2, + }, + } + assert.Equal(t, epectedMap, resMap) + }) + + t.Run("struct to map[string]struct, 重复的key", func(t *testing.T) { + type eleType struct { + A string + B string + C int + } + elements := []eleType{ + { + A: "a", + B: "b", + C: 1, + }, + { + A: "c", + B: "d", + C: 2, + }, + } + resMap := ToMap(elements, func(ele eleType) string { + return ele.A + }) + epectedMap := map[string]eleType{ + "a": { + A: "a", + B: "b", + C: 1, + }, + "c": { + A: "c", + B: "d", + C: 2, + }, + } + assert.Equal(t, epectedMap, resMap) + }) + + t.Run("传入nil slice,返回空map", func(t *testing.T) { + var elements []string = nil + resMap := ToMap(elements, func(str string) int { + num, _ := strconv.Atoi(str) + return num + }) + epectedMap := make(map[int]string) + assert.Equal(t, epectedMap, resMap) + }) +} + +func ExampleToMap() { + elements := []string{"1", "2", "3", "4", "5"} + resMap := ToMap(elements, func(str string) int { + num, _ := strconv.Atoi(str) + return num + }) + fmt.Println(resMap) + // Output: map[1:1 2:2 3:3 4:4 5:5] +} + +func ExampleToMapV() { + type eleType struct { + A string + B string + C int + } + type eleTypeOut struct { + A string + B string + } + elements := []eleType{ + { + A: "a", + B: "b", + C: 1, + }, + { + A: "c", + B: "d", + C: 2, + }, + } + resMap := ToMapV(elements, func(ele eleType) (string, eleTypeOut) { + return ele.A, eleTypeOut{ + A: ele.A, + B: ele.B, + } + }) + fmt.Println(resMap) + // Output: map[a:{a b} c:{c d}] +} diff --git a/slice/reverse.go b/slice/reverse.go index 01d418c2..14a47236 100644 --- a/slice/reverse.go +++ b/slice/reverse.go @@ -15,7 +15,7 @@ package slice // Reverse 将会完全创建一个新的切片,而不是直接在 src 上进行翻转。 -func Reverse[T comparable](src []T) []T { +func Reverse[T any](src []T) []T { var ret = make([]T, 0, len(src)) for i := len(src) - 1; i >= 0; i-- { ret = append(ret, src[i]) @@ -24,7 +24,7 @@ func Reverse[T comparable](src []T) []T { } // ReverseSelf 會直接在 src 上进行翻转。 -func ReverseSelf[T comparable](src []T) { +func ReverseSelf[T any](src []T) { for i, j := 0, len(src)-1; i < j; i, j = i+1, j-1 { src[i], src[j] = src[j], src[i] } diff --git a/slice/reverse_test.go b/slice/reverse_test.go index a422ebd7..12abf45f 100644 --- a/slice/reverse_test.go +++ b/slice/reverse_test.go @@ -47,6 +47,42 @@ func TestReverseInt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { res := Reverse[int](tt.src) assert.ElementsMatch(t, tt.want, res) + assert.NotSame(t, tt.src, res) + }) + } +} + +func TestReverseStruct(t *testing.T) { + type testStruct struct { + A int + B []int + } + tests := []struct { + name string + src []testStruct + want []testStruct + }{ + { + want: []testStruct{{1, []int{1, 2, 3}}, {3, []int{4, 5, 6}}, {5, []int{7, 8, 9}}}, + src: []testStruct{{5, []int{7, 8, 9}}, {3, []int{4, 5, 6}}, {1, []int{1, 2, 3}}}, + name: "normal test", + }, + { + src: []testStruct{}, + want: []testStruct{}, + name: "length of src is 0", + }, + { + src: nil, + want: []testStruct{}, + name: "length of src is nil", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res := Reverse[testStruct](tt.src) + assert.ElementsMatch(t, tt.want, res) + assert.NotSame(t, tt.src, res) }) } } @@ -81,6 +117,40 @@ func TestReverseSelfInt(t *testing.T) { } } +func TestReverseSelfStruct(t *testing.T) { + type testStruct struct { + A int + B []int + } + tests := []struct { + name string + src []testStruct + want []testStruct + }{ + { + want: []testStruct{{1, []int{1, 2, 3}}, {3, []int{4, 5, 6}}, {5, []int{7, 8, 9}}}, + src: []testStruct{{5, []int{7, 8, 9}}, {3, []int{4, 5, 6}}, {1, []int{1, 2, 3}}}, + name: "normal test", + }, + { + src: []testStruct{}, + want: []testStruct{}, + name: "length of src is 0", + }, + { + src: nil, + want: []testStruct{}, + name: "length of src is nil", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ReverseSelf[testStruct](tt.src) + assert.ElementsMatch(t, tt.want, tt.src) + }) + } +} + func ExampleReverse() { res := Reverse[int]([]int{1, 3, 2, 2, 4}) fmt.Println(res) diff --git a/sqlx/newnull.go b/sqlx/newnull.go new file mode 100644 index 00000000..3aebec98 --- /dev/null +++ b/sqlx/newnull.go @@ -0,0 +1,46 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import ( + "database/sql" + "time" +) + +//这一个系列的方法,会在数据为零值时,将valid设置为false;否则设置为true; + +func NewNullString(val string) sql.NullString { + return sql.NullString{String: val, Valid: val != ""} +} + +func NewNullInt64(val int64) sql.NullInt64 { + return sql.NullInt64{Int64: val, Valid: val != 0} +} + +func NewNullFloat64(val float64) sql.NullFloat64 { + return sql.NullFloat64{Float64: val, Valid: val != 0} +} + +func NewNullBool(val bool) sql.NullBool { + return sql.NullBool{Bool: val, Valid: val} +} + +func NewNullTime(val time.Time) sql.NullTime { + return sql.NullTime{Time: val, Valid: !val.IsZero()} +} + +func NewNullBytes(val []byte) sql.NullString { + return sql.NullString{String: string(val), Valid: len(val) > 0} +} diff --git a/sqlx/newnull_test.go b/sqlx/newnull_test.go new file mode 100644 index 00000000..ed19457b --- /dev/null +++ b/sqlx/newnull_test.go @@ -0,0 +1,203 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlx + +import ( + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewNullBool(t *testing.T) { + tests := []struct { + name string + val bool + want sql.NullBool + }{ + { + name: "nonzero", + val: true, + want: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + { + name: "zero", + val: false, + want: sql.NullBool{ + Bool: false, + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullBool(tt.val), "NewNullBool(%v)", tt.val) + }) + } +} + +func TestNewNullBytes(t *testing.T) { + tests := []struct { + name string + val []byte + want sql.NullString + }{ + { + name: "nonzero", + val: []byte("test"), + want: sql.NullString{ + String: "test", + Valid: true, + }, + }, + { + name: "zero", + val: []byte{}, + want: sql.NullString{ + String: "", + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullBytes(tt.val), "NewNullBytes(%v)", tt.val) + }) + } +} + +func TestNewNullFloat64(t *testing.T) { + tests := []struct { + name string + val float64 + want sql.NullFloat64 + }{ + { + name: "nonzero", + val: 1.1, + want: sql.NullFloat64{ + Float64: 1.1, + Valid: true, + }, + }, + { + name: "zero", + val: 0, + want: sql.NullFloat64{ + Float64: 0, + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullFloat64(tt.val), "NewNullFloat64(%v)", tt.val) + }) + } +} + +func TestNewNullInt64(t *testing.T) { + tests := []struct { + name string + val int64 + want sql.NullInt64 + }{ + { + name: "nonzero", + val: 1, + want: sql.NullInt64{ + Int64: 1, + Valid: true, + }, + }, + { + name: "zero", + val: 0, + want: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullInt64(tt.val), "NewNullInt64(%v)", tt.val) + }) + } +} + +func TestNewNullString(t *testing.T) { + tests := []struct { + name string + val string + want sql.NullString + }{ + { + name: "nonzero", + val: "test", + want: sql.NullString{ + String: "test", + Valid: true, + }, + }, + { + name: "zero", + val: "", + want: sql.NullString{ + String: "", + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullString(tt.val), "NewNullString(%v)", tt.val) + }) + } +} + +func TestNewNullTime(t *testing.T) { + tests := []struct { + name string + val time.Time + want sql.NullTime + }{ + { + name: "nonzero", + val: time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC), + want: sql.NullTime{ + Time: time.Date(2023, 10, 1, 12, 0, 0, 0, time.UTC), + Valid: true, + }, + }, + { + name: "zero", + val: time.Time{}, + want: sql.NullTime{ + Time: time.Time{}, + Valid: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NewNullTime(tt.val), "NewNullTime(%v)", tt.val) + }) + } +} diff --git a/sqlx/scanner.go b/sqlx/scanner.go index 5b1b296e..e39067a3 100644 --- a/sqlx/scanner.go +++ b/sqlx/scanner.go @@ -15,6 +15,8 @@ package sqlx import ( + "bytes" + "database/sql" "errors" "fmt" "reflect" @@ -54,6 +56,7 @@ func NewSQLRowsScanner(r Rows) (Scanner, error) { for i, columnType := range columnTypes { typ := columnType.ScanType() for typ.Kind() == reflect.Pointer { + // 兼容 sqlite,理论上来说其他 driver 不应该命中这个分支 typ = typ.Elem() } columnValuePointers[i] = reflect.New(typ).Interface() @@ -84,7 +87,12 @@ func (s *sqlRowsScanner) Scan() ([]any, error) { func (s *sqlRowsScanner) columnValues() []any { values := make([]any, len(s.columnValuePointers)) for i := 0; i < len(s.columnValuePointers); i++ { - values[i] = reflect.ValueOf(s.columnValuePointers[i]).Elem().Interface() + val := reflect.ValueOf(s.columnValuePointers[i]).Elem().Interface() + // sql.RawBytes 存在内存共享的问题,所以需要执行复制 + if rawBytes, ok := val.(sql.RawBytes); ok { + val = sql.RawBytes(bytes.Clone(rawBytes)) + } + values[i] = val } return values } diff --git a/stringx/string_example_test.go b/stringx/string_example_test.go new file mode 100644 index 00000000..80ea91cc --- /dev/null +++ b/stringx/string_example_test.go @@ -0,0 +1,36 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stringx_test + +import ( + "fmt" + + "github.com/ecodeclub/ekit/stringx" +) + +func ExampleUnsafeToBytes() { + str := "hello" + val := stringx.UnsafeToBytes(str) + fmt.Println(len(val)) + // Output: + // 5 +} + +func ExampleUnsafeToString() { + val := stringx.UnsafeToString([]byte("hello")) + fmt.Println(val) + // Output: + // hello +} diff --git a/syncx/limit_pool.go b/syncx/limit_pool.go new file mode 100644 index 00000000..b7427f74 --- /dev/null +++ b/syncx/limit_pool.go @@ -0,0 +1,53 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "sync/atomic" +) + +// LimitPool 是对 Pool 的简单封装允许用户通过控制一段时间内对Pool的令牌申请次数来间接控制Pool中对象的内存总占用量 +type LimitPool[T any] struct { + pool *Pool[T] + tokens *atomic.Int32 +} + +// NewLimitPool 创建一个 LimitPool 实例 +// maxTokens 表示一段时间内的允许发放的最大令牌数 +// factory 必须返回 T 类型的值,并且不能返回 nil +func NewLimitPool[T any](maxTokens int, factory func() T) *LimitPool[T] { + var tokens atomic.Int32 + tokens.Add(int32(maxTokens)) + return &LimitPool[T]{ + pool: NewPool[T](factory), + tokens: &tokens, + } +} + +// Get 取出一个元素 +func (l *LimitPool[T]) Get() T { + if l.tokens.Add(-1) < 0 { + l.tokens.Add(1) + var zero T + return zero + } + return l.pool.Get() +} + +// Put 放回去一个元素 +func (l *LimitPool[T]) Put(t T) { + l.pool.Put(t) + l.tokens.Add(1) +} diff --git a/syncx/limit_pool_test.go b/syncx/limit_pool_test.go new file mode 100644 index 00000000..aad405c8 --- /dev/null +++ b/syncx/limit_pool_test.go @@ -0,0 +1,68 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "bytes" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLimitPool(t *testing.T) { + + expectedMaxAttempts := 3 + expectedVal := []byte("A") + + pool := NewLimitPool(expectedMaxAttempts, func() []byte { + var buffer bytes.Buffer + buffer.Write(expectedVal) + return buffer.Bytes() + }) + + var wg sync.WaitGroup + bufChan := make(chan []byte, expectedMaxAttempts) + + // 从Pool中并发获取缓冲区 + for i := 0; i < expectedMaxAttempts; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + buf := pool.Get() + + assert.NotZero(t, buf) + assert.Equal(t, string(expectedVal), string(buf)) + + bufChan <- buf + }() + } + + wg.Wait() + close(bufChan) + + // 超过最大申请次数返回零值 + assert.Zero(t, pool.Get()) + + // 归还一个 + pool.Put(<-bufChan) + + // 再次申请仍可以拿到非零值缓冲区 + assert.NotZero(t, string(expectedVal), string(pool.Get())) + + // 超过最大申请次数返回零值 + assert.Zero(t, pool.Get()) +} diff --git a/syncx/segment_key_lock.go b/syncx/segment_key_lock.go new file mode 100644 index 00000000..59f5de38 --- /dev/null +++ b/syncx/segment_key_lock.go @@ -0,0 +1,80 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "hash/fnv" + "sync" +) + +// SegmentKeysLock 部分key lock结构定义 +type SegmentKeysLock struct { + locks []*sync.RWMutex + size uint32 +} + +// NewSegmentKeysLock 创建 SegmentKeysLock 示例 +func NewSegmentKeysLock(size uint32) *SegmentKeysLock { + locks := make([]*sync.RWMutex, size) + for i := range locks { + locks[i] = &sync.RWMutex{} + } + return &SegmentKeysLock{ + locks: locks, + size: size, + } +} + +// hash 索引锁的hash函数 +func (s *SegmentKeysLock) hash(key string) uint32 { + h := fnv.New32a() + _, _ = h.Write([]byte(key)) + return h.Sum32() +} + +// RLock 读锁加锁 +func (s *SegmentKeysLock) RLock(key string) { + s.getLock(key).RLock() +} + +// TryRLock 试着加读锁,加锁成功会返回 +func (s *SegmentKeysLock) TryRLock(key string) bool { + return s.getLock(key).TryRLock() +} + +// RUnlock 读锁解锁 +func (s *SegmentKeysLock) RUnlock(key string) { + s.getLock(key).RUnlock() +} + +// Lock 写锁加锁 +func (s *SegmentKeysLock) Lock(key string) { + s.getLock(key).Lock() +} + +// TryLock 试着加锁,加锁成功会返回 true +func (s *SegmentKeysLock) TryLock(key string) bool { + return s.getLock(key).TryLock() +} + +// Unlock 写锁解锁 +func (s *SegmentKeysLock) Unlock(key string) { + s.getLock(key).Unlock() +} + +func (s *SegmentKeysLock) getLock(key string) *sync.RWMutex { + hash := s.hash(key) + return s.locks[hash%s.size] +} diff --git a/syncx/segment_key_lock_test.go b/syncx/segment_key_lock_test.go new file mode 100644 index 00000000..f927dbc6 --- /dev/null +++ b/syncx/segment_key_lock_test.go @@ -0,0 +1,67 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// 通过 TryLock 和 TryRLock 来判定加锁问题 +// 也就是只判定我们拿到了正确的锁,但是没有判定并发与互斥 + +// TestNewSegmentKeysLock_Lock 测试 Lock, UnLock 和 TryLock +func TestNewSegmentKeysLock_Lock(t *testing.T) { + l := NewSegmentKeysLock(8) + key1 := "key1" + l.Lock(key1) + // 必然加锁失败 + assert.False(t, l.TryLock(key1)) + // 读锁也失败 + assert.False(t, l.TryRLock(key1)) + key2 := "key2" + // 加锁成功 + assert.True(t, l.TryLock(key2)) + // 解锁不会触发 panic + defer l.Unlock(key2) + + // 释放锁 + l.Unlock(key1) + // 此时应该预期自己可以再次加锁 + assert.True(t, l.TryLock(key1)) +} + +func TestNewSegmentKeysLock_RLock(t *testing.T) { + l := NewSegmentKeysLock(8) + key1, key2 := "key1", "key2" + l.RLock(key1) + // 必然加锁失败 + assert.False(t, l.TryLock(key1)) + // 读锁可以成功 + assert.True(t, l.TryRLock(key1)) + // 加锁成功 + assert.True(t, l.TryRLock(key2)) + // 解锁不会触发 panic + defer l.RUnlock(key2) + + // 释放读锁 + l.RUnlock(key1) + // 此时还有一个读锁没有释放 + assert.False(t, l.TryLock(key1)) + // 再次释放读锁 + l.RUnlock(key1) + assert.True(t, l.TryLock(key1)) +} diff --git a/tree/red_black_tree.go b/tree/red_black_tree.go new file mode 100644 index 00000000..fc78c09d --- /dev/null +++ b/tree/red_black_tree.go @@ -0,0 +1,71 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tree + +import ( + "errors" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/tree" +) + +var ( + errRBTreeComparatorIsNull = errors.New("ekit: RBTree 的 Comparator 不能为 nil") +) + +// RBTree 简单的封装一下红黑树 +type RBTree[K any, V any] struct { + rbTree *tree.RBTree[K, V] //红黑树本体 +} + +func NewRBTree[K any, V any](compare ekit.Comparator[K]) (*RBTree[K, V], error) { + if nil == compare { + return nil, errRBTreeComparatorIsNull + } + + return &RBTree[K, V]{ + rbTree: tree.NewRBTree[K, V](compare), + }, nil +} + +// Add 增加节点 +func (rb *RBTree[K, V]) Add(key K, value V) error { + return rb.rbTree.Add(key, value) +} + +// Delete 删除节点 +func (rb *RBTree[K, V]) Delete(key K) (V, bool) { + return rb.rbTree.Delete(key) +} + +// Set 修改节点 +func (rb *RBTree[K, V]) Set(key K, value V) error { + return rb.rbTree.Set(key, value) +} + +// Find 查找节点 +func (rb *RBTree[K, V]) Find(key K) (V, error) { + return rb.rbTree.Find(key) +} + +// Size 返回红黑树结点个数 +func (rb *RBTree[K, V]) Size() int { + return rb.rbTree.Size() +} + +// KeyValues 获取红黑树所有节点K,V +func (rb *RBTree[K, V]) KeyValues() ([]K, []V) { + return rb.rbTree.KeyValues() +} diff --git a/tree/red_black_tree_test.go b/tree/red_black_tree_test.go new file mode 100644 index 00000000..ee73cccb --- /dev/null +++ b/tree/red_black_tree_test.go @@ -0,0 +1,187 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tree + +import ( + "testing" + + "github.com/ecodeclub/ekit" + "github.com/ecodeclub/ekit/internal/tree" + "github.com/stretchr/testify/assert" +) + +func compare() ekit.Comparator[int] { + return ekit.ComparatorRealNumber[int] +} + +func TestNewRBTree(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantErr error + }{ + { + name: "compare is nil", + compare: nil, + wantErr: errRBTreeComparatorIsNull, + }, + { + name: "compare is ok", + compare: compare(), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := NewRBTree[int, string](tc.compare) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestRBTree_Add(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + key int + value string + wantErr error + }{ + { + name: "no err is ok", + compare: compare(), + key: 1, + value: "value1", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + err := rbTree.Add(tc.key, tc.value) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestRBTree_Delete(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + key int + wantBool bool + }{ + { + name: "no err is ok", + compare: compare(), + key: 1, + wantBool: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + _, resultBool := rbTree.Delete(tc.key) + assert.Equal(t, tc.wantBool, resultBool) + }) + } +} + +func TestRBTree_Set(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + key int + value string + wantErr error + }{ + { + name: "no err is ok", + compare: compare(), + key: 1, + value: "value1", + wantErr: tree.ErrRBTreeNotRBNode, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + err := rbTree.Set(tc.key, tc.value) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestRBTree_Find(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + key int + wantErr error + }{ + { + name: "no err is ok", + compare: compare(), + key: 1, + wantErr: tree.ErrRBTreeNotRBNode, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + _, err := rbTree.Find(tc.key) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestRBTree_Size(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + wantSize int + }{ + { + name: "no err is ok", + compare: compare(), + wantSize: 0, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + size := rbTree.Size() + assert.Equal(t, tc.wantSize, size) + }) + } +} + +func TestRBTree_KeyValues(t *testing.T) { + testCases := []struct { + name string + compare ekit.Comparator[int] + }{ + { + name: "no err is ok", + compare: compare(), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rbTree, _ := NewRBTree[int, string](tc.compare) + keys, values := rbTree.KeyValues() + assert.Equal(t, 0, len(keys)) + assert.Equal(t, 0, len(values)) + }) + } +} diff --git a/tuple/pair/pair.go b/tuple/pair/pair.go new file mode 100644 index 00000000..b692d54f --- /dev/null +++ b/tuple/pair/pair.go @@ -0,0 +1,132 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pair + +import ( + "fmt" +) + +type Pair[K any, V any] struct { + Key K + Value V +} + +func (pair *Pair[K, V]) String() string { + return fmt.Sprintf("<%#v, %#v>", pair.Key, pair.Value) +} + +// Split 方法将Key, Value作为返回参数传出。 +func (pair *Pair[K, V]) Split() (K, V) { + return pair.Key, pair.Value +} + +func NewPair[K any, V any]( + key K, + value V, +) Pair[K, V] { + return Pair[K, V]{ + Key: key, + Value: value, + } +} + +// NewPairs 需要传入两个长度相同并且均不为nil的数组 keys 和 values, +// 设keys长度为n,返回一个长度为n的pair数组。 +// 保证: +// +// 返回的pair数组满足条件(设pair数组为p): +// 对于所有的 0 <= i < n +// p[i].Key == keys[i] 并且 p[i].Value == values[i] +// +// 如果传入的keys或者values为nil,会返回error +// +// 如果传入的keys长度与values长度不同,会返回error +func NewPairs[K any, V any]( + keys []K, + values []V, +) ([]Pair[K, V], error) { + if keys == nil || values == nil { + return nil, fmt.Errorf("keys与values均不可为nil") + } + n := len(keys) + if n != len(values) { + return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values)) + } + pairs := make([]Pair[K, V], n) + for i := 0; i < n; i++ { + pairs[i] = NewPair(keys[i], values[i]) + } + return pairs, nil +} + +// SplitPairs 需要传入一个[]Pair[K, V],数组可以为nil。 +// 设pairs数组的长度为n,返回两个长度均为n的数组keys, values。 +// 如果pairs数组是nil, 则返回的keys与values也均为nil。 +func SplitPairs[K any, V any](pairs []Pair[K, V]) (keys []K, values []V) { + if pairs == nil { + return nil, nil + } + n := len(pairs) + keys = make([]K, n) + values = make([]V, n) + for i, pair := range pairs { + keys[i], values[i] = pair.Split() + } + return +} + +// FlattenPairs 需要传入一个[]Pair[K, V],数组可以为nil +// 如果pairs数组为nil,则返回的flatPairs数组也为nil +// +// 设pairs数组长度为n,保证返回的flatPairs数组长度为2 * n且满足: +// 对于所有的 0 <= i < n +// flatPairs[i * 2] == pairs[i].Key +// flatPairs[i * 2 + 1] == pairs[i].Value +func FlattenPairs[K any, V any](pairs []Pair[K, V]) (flatPairs []any) { + if pairs == nil { + return nil + } + n := len(pairs) + flatPairs = make([]any, 0, n*2) + for _, pair := range pairs { + flatPairs = append(flatPairs, pair.Key, pair.Value) + } + return +} + +// PackPairs 需要传入一个长度为2 * n的数组flatPairs,数组可以为nil。 +// +// 函数将会返回一个长度为n的pairs数组,pairs满足 +// 对于所有的 0 <= i < n +// pairs[i].Key == flatPairs[i * 2] +// pairs[i].Value == flatPairs[i * 2 + 1] +// 如果flatPairs为nil,则返回的pairs也为nil +// +// 入参flatPairs需要满足以下条件: +// 对于所有的 0 <= i < n +// flatPairs[i * 2] 的类型为 K +// flatPairs[i * 2 + 1] 的类型为 V +// 否则会panic +func PackPairs[K any, V any](flatPairs []any) (pairs []Pair[K, V]) { + if flatPairs == nil { + return nil + } + n := len(flatPairs) / 2 + pairs = make([]Pair[K, V], n) + for i := 0; i < n; i++ { + pairs[i] = NewPair(flatPairs[i*2].(K), flatPairs[i*2+1].(V)) + } + return +} diff --git a/tuple/pair/pair_test.go b/tuple/pair/pair_test.go new file mode 100644 index 00000000..233d08aa --- /dev/null +++ b/tuple/pair/pair_test.go @@ -0,0 +1,237 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pair_test + +import ( + "fmt" + "sort" + "testing" + + "github.com/ecodeclub/ekit/mapx" + "github.com/ecodeclub/ekit/tuple/pair" + "github.com/stretchr/testify/suite" +) + +type testPairSuite struct{ suite.Suite } + +func (s *testPairSuite) TestString() { + { + p := pair.NewPair(100, "23333") + s.Assert().Equal("<100, \"23333\">", p.String()) + } + { + p := pair.NewPair("testStruct", map[int]int{ + 11: 1, + 22: 2, + 33: 3, + }) + s.Assert().Equal("<\"testStruct\", map[int]int{11:1, 22:2, 33:3}>", p.String()) + } +} + +func (s *testPairSuite) TestNewPairs() { + type caseType struct { + // input + keys []int + values []string + // expected + pairs []pair.Pair[int, string] + err error + } + for _, c := range []caseType{ + { + keys: []int{1, 2, 3, 4, 5}, + values: []string{"1", "2", "3", "4", "5"}, + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + err: nil, + }, + { + keys: nil, + values: []string{"1"}, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1}, + values: nil, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: nil, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1, 2}, + values: []string{"1"}, + pairs: nil, + err: fmt.Errorf("keys与values的长度不同, len(keys)=2, len(values)=1"), + }, + } { + pairs, err := pair.NewPairs(c.keys, c.values) + s.Assert().Equal(c.err, err) + s.Assert().EqualValues(c.pairs, pairs) + } +} + +func (s *testPairSuite) TestSplitPairs() { + type caseType struct { + // input + pairs []pair.Pair[int, string] + // expected + keys []int + values []string + } + for _, c := range []caseType{ + { + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + keys: []int{1, 2, 3, 4, 5}, + values: []string{"1", "2", "3", "4", "5"}, + }, + { + pairs: nil, + + keys: nil, + values: nil, + }, + { + pairs: []pair.Pair[int, string]{}, + keys: []int{}, + values: []string{}, + }, + } { + keys, values := pair.SplitPairs(c.pairs) + if c.pairs == nil { + s.Assert().Nil(keys) + s.Assert().Nil(values) + } else { + s.Assert().Len(keys, len(c.pairs)) + s.Assert().Len(values, len(c.pairs)) + for i, pair := range c.pairs { + s.Assert().Equal(pair.Key, keys[i]) + s.Assert().Equal(pair.Value, values[i]) + } + } + } +} + +func (s *testPairSuite) TestFlattenPairs() { + type caseType struct { + pairs []pair.Pair[int, string] + flattPairs []any + } + + for _, c := range []caseType{ + { + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + flattPairs: []any{1, "1", 2, "2", 3, "3", 4, "4", 5, "5"}, + }, + { + pairs: nil, + flattPairs: nil, + }, + { + pairs: []pair.Pair[int, string]{}, + flattPairs: []any{}, + }, + } { + flatPairs := pair.FlattenPairs(c.pairs) + s.Assert().EqualValues(c.flattPairs, flatPairs) + } +} + +func (s *testPairSuite) TestPackPairs() { + type caseType struct { + flattPairs []any + pairs []pair.Pair[int, string] + } + + for _, c := range []caseType{ + { + flattPairs: []any{1, "1", 2, "2", 3, "3", 4, "4", 5, "5"}, + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + }, + { + flattPairs: nil, + pairs: nil, + }, + { + flattPairs: []any{}, + pairs: []pair.Pair[int, string]{}, + }, + } { + pairs := pair.PackPairs[int, string](c.flattPairs) + s.Assert().EqualValues(c.pairs, pairs) + } +} + +func (s *testPairSuite) TestMapPairMapping() { + // map to pairs + expectedMap := map[int]string{ + 1: "1", + 2: "2", + 3: "3", + } + expectedPairs := []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + } + + // 可以用这种方式实现map到[]Pair的映射 + pairs, err := pair.NewPairs(mapx.KeysValues(expectedMap)) + s.Assert().Nil(err) + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].Key < pairs[j].Key + }) + s.Assert().EqualValues(expectedPairs, pairs) + + // 可以用这种方式实现[]Pair到map的映射 + mp, err := mapx.ToMap(pair.SplitPairs(expectedPairs)) + s.Assert().Nil(err) + for k, v := range mp { + s.Assert().Equal(expectedMap[k], v) + } +} + +func TestPair(t *testing.T) { + suite.Run(t, new(testPairSuite)) +} diff --git a/value.go b/value.go index a2e89bb4..66108e12 100644 --- a/value.go +++ b/value.go @@ -15,8 +15,8 @@ package ekit import ( + "encoding/json" "errors" - "fmt" "reflect" "strconv" @@ -36,7 +36,7 @@ func (av AnyValue) Int() (int, error) { } val, ok := av.Val.(int) if !ok { - return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int", av.Val) } return val, nil } @@ -52,7 +52,7 @@ func (av AnyValue) AsInt() (int, error) { res, err := strconv.ParseInt(v, 10, 64) return int(res), err } - return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int", av.Val) } // IntOrDefault 返回 int 数据,或者默认值 @@ -71,7 +71,7 @@ func (av AnyValue) Uint() (uint, error) { } val, ok := av.Val.(uint) if !ok { - return 0, errs.NewErrInvalidType("uint", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint", av.Val) } return val, nil } @@ -87,7 +87,7 @@ func (av AnyValue) AsUint() (uint, error) { res, err := strconv.ParseUint(v, 10, 64) return uint(res), err } - return 0, errs.NewErrInvalidType("uint", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint", av.Val) } // UintOrDefault 返回 uint 数据,或者默认值 @@ -105,7 +105,7 @@ func (av AnyValue) Int8() (int8, error) { } val, ok := av.Val.(int8) if !ok { - return 0, errs.NewErrInvalidType("int", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int", av.Val) } return val, nil } @@ -122,7 +122,7 @@ func (av AnyValue) AsInt8() (int8, error) { res, err := strconv.ParseInt(v, 10, 64) return int8(res), err } - return 0, errs.NewErrInvalidType("int8", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int8", av.Val) } func (av AnyValue) Int8OrDefault(def int8) int8 { @@ -139,7 +139,7 @@ func (av AnyValue) Uint8() (uint8, error) { } val, ok := av.Val.(uint8) if !ok { - return 0, errs.NewErrInvalidType("uint8", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint8", av.Val) } return val, nil } @@ -156,7 +156,7 @@ func (av AnyValue) AsUint8() (uint8, error) { res, err := strconv.ParseUint(v, 10, 8) return uint8(res), err } - return 0, errs.NewErrInvalidType("uint8", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint8", av.Val) } func (av AnyValue) Uint8OrDefault(def uint8) uint8 { @@ -173,7 +173,7 @@ func (av AnyValue) Int16() (int16, error) { } val, ok := av.Val.(int16) if !ok { - return 0, errs.NewErrInvalidType("int16", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int16", av.Val) } return val, nil } @@ -190,7 +190,7 @@ func (av AnyValue) AsInt16() (int16, error) { res, err := strconv.ParseInt(v, 10, 16) return int16(res), err } - return 0, errs.NewErrInvalidType("int16", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int16", av.Val) } func (av AnyValue) Int16OrDefault(def int16) int16 { @@ -207,7 +207,7 @@ func (av AnyValue) Uint16() (uint16, error) { } val, ok := av.Val.(uint16) if !ok { - return 0, errs.NewErrInvalidType("uint16", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint16", av.Val) } return val, nil } @@ -224,7 +224,7 @@ func (av AnyValue) AsUint16() (uint16, error) { res, err := strconv.ParseUint(v, 10, 16) return uint16(res), err } - return 0, errs.NewErrInvalidType("uint16", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint16", av.Val) } func (av AnyValue) Uint16OrDefault(def uint16) uint16 { @@ -242,7 +242,7 @@ func (av AnyValue) Int32() (int32, error) { } val, ok := av.Val.(int32) if !ok { - return 0, errs.NewErrInvalidType("int32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int32", av.Val) } return val, nil } @@ -258,7 +258,7 @@ func (av AnyValue) AsInt32() (int32, error) { res, err := strconv.ParseInt(v, 10, 32) return int32(res), err } - return 0, errs.NewErrInvalidType("int32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int32", av.Val) } // Int32OrDefault 返回 int32 数据,或者默认值 @@ -277,7 +277,7 @@ func (av AnyValue) Uint32() (uint32, error) { } val, ok := av.Val.(uint32) if !ok { - return 0, errs.NewErrInvalidType("uint32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint32", av.Val) } return val, nil } @@ -293,7 +293,7 @@ func (av AnyValue) AsUint32() (uint32, error) { res, err := strconv.ParseUint(v, 10, 32) return uint32(res), err } - return 0, errs.NewErrInvalidType("uint32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint32", av.Val) } // Uint32OrDefault 返回 uint32 数据,或者默认值 @@ -312,7 +312,7 @@ func (av AnyValue) Int64() (int64, error) { } val, ok := av.Val.(int64) if !ok { - return 0, errs.NewErrInvalidType("int64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int64", av.Val) } return val, nil } @@ -327,7 +327,7 @@ func (av AnyValue) AsInt64() (int64, error) { case string: return strconv.ParseInt(v, 10, 64) } - return 0, errs.NewErrInvalidType("int64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("int64", av.Val) } // Int64OrDefault 返回 int64 数据,或者默认值 @@ -346,7 +346,7 @@ func (av AnyValue) Uint64() (uint64, error) { } val, ok := av.Val.(uint64) if !ok { - return 0, errs.NewErrInvalidType("uint64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint64", av.Val) } return val, nil } @@ -361,7 +361,7 @@ func (av AnyValue) AsUint64() (uint64, error) { case string: return strconv.ParseUint(v, 10, 64) } - return 0, errs.NewErrInvalidType("uint64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("uint64", av.Val) } // Uint64OrDefault 返回 uint64 数据,或者默认值 @@ -380,7 +380,7 @@ func (av AnyValue) Float32() (float32, error) { } val, ok := av.Val.(float32) if !ok { - return 0, errs.NewErrInvalidType("float32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("float32", av.Val) } return val, nil } @@ -396,7 +396,7 @@ func (av AnyValue) AsFloat32() (float32, error) { res, err := strconv.ParseFloat(v, 32) return float32(res), err } - return 0, errs.NewErrInvalidType("float32", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("float32", av.Val) } // Float32OrDefault 返回 float32 数据,或者默认值 @@ -415,7 +415,7 @@ func (av AnyValue) Float64() (float64, error) { } val, ok := av.Val.(float64) if !ok { - return 0, errs.NewErrInvalidType("float64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("float64", av.Val) } return val, nil } @@ -430,7 +430,7 @@ func (av AnyValue) AsFloat64() (float64, error) { case string: return strconv.ParseFloat(v, 64) } - return 0, errs.NewErrInvalidType("float64", reflect.TypeOf(av.Val).String()) + return 0, errs.NewErrInvalidType("float64", av.Val) } // Float64OrDefault 返回 float64 数据,或者默认值 @@ -449,7 +449,7 @@ func (av AnyValue) String() (string, error) { } val, ok := av.Val.(string) if !ok { - return "", errs.NewErrInvalidType("string", reflect.TypeOf(av.Val).String()) + return "", errs.NewErrInvalidType("string", av.Val) } return val, nil } @@ -474,7 +474,7 @@ func (av AnyValue) AsString() (string, error) { val = strconv.FormatFloat(valueOf.Float(), 'f', 10, 64) case reflect.Slice: if valueOf.Type().Elem().Kind() != reflect.Uint8 { - return "", errs.NewErrInvalidType("[]byte", fmt.Sprintf("[]%s", valueOf.Type().Elem().Kind())) + return "", errs.NewErrInvalidType("[]byte", av.Val) } val = string(valueOf.Bytes()) default: @@ -500,7 +500,7 @@ func (av AnyValue) Bytes() ([]byte, error) { } val, ok := av.Val.([]byte) if !ok { - return nil, errs.NewErrInvalidType("[]byte", reflect.TypeOf(av.Val).String()) + return nil, errs.NewErrInvalidType("[]byte", av.Val) } return val, nil } @@ -516,7 +516,7 @@ func (av AnyValue) AsBytes() ([]byte, error) { return []byte(v), nil } - return []byte{}, errs.NewErrInvalidType("[]byte", reflect.TypeOf(av.Val).String()) + return []byte{}, errs.NewErrInvalidType("[]byte", av.Val) } // BytesOrDefault 返回 []byte 数据,或者默认值 @@ -535,7 +535,7 @@ func (av AnyValue) Bool() (bool, error) { } val, ok := av.Val.(bool) if !ok { - return false, errs.NewErrInvalidType("bool", reflect.TypeOf(av.Val).String()) + return false, errs.NewErrInvalidType("bool", av.Val) } return val, nil } @@ -548,3 +548,12 @@ func (av AnyValue) BoolOrDefault(def bool) bool { } return val } + +// JSONScan 将 val 转化为一个对象 +func (av AnyValue) JSONScan(val any) error { + data, err := av.AsBytes() + if err != nil { + return err + } + return json.Unmarshal(data, val) +} diff --git a/value_test.go b/value_test.go index 8e9b9002..46cec5a2 100644 --- a/value_test.go +++ b/value_test.go @@ -16,7 +16,6 @@ package ekit import ( "errors" - "reflect" "testing" "github.com/ecodeclub/ekit/internal/errs" @@ -50,7 +49,7 @@ func TestAnyValue_Int(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("int", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("int", ""), }, } for _, tt := range tests { @@ -136,7 +135,7 @@ func TestAnyValue_Uint(t *testing.T) { val: AnyValue{ Val: []string{"111"}, }, - err: errs.NewErrInvalidType("uint", reflect.TypeOf([]string{"111"}).String()), + err: errs.NewErrInvalidType("uint", []string{"111"}), }, } for _, tt := range tests { @@ -222,7 +221,7 @@ func TestAnyValue_Int32(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("int32", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("int32", ""), }, } for _, tt := range tests { @@ -308,7 +307,7 @@ func TestAnyValue_Uint32(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("uint32", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("uint32", ""), }, } for _, tt := range tests { @@ -395,7 +394,7 @@ func TestAnyValue_Int64(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("int64", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("int64", ""), }, } for _, tt := range tests { @@ -481,7 +480,7 @@ func TestAnyValue_Uint64(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("uint64", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("uint64", ""), }, } for _, tt := range tests { @@ -567,7 +566,7 @@ func TestAnyValue_Float32(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("float32", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("float32", ""), }, } for _, tt := range tests { @@ -654,7 +653,7 @@ func TestAnyValue_Float64(t *testing.T) { val: AnyValue{ Val: "", }, - err: errs.NewErrInvalidType("float64", reflect.TypeOf("").String()), + err: errs.NewErrInvalidType("float64", ""), }, } for _, tt := range tests { @@ -740,7 +739,7 @@ func TestAnyValue_String(t *testing.T) { val: AnyValue{ Val: 1, }, - err: errs.NewErrInvalidType("string", reflect.TypeOf(111).String()), + err: errs.NewErrInvalidType("string", 1), }, } for _, tt := range tests { @@ -826,7 +825,7 @@ func TestAnyValue_Bytes(t *testing.T) { val: AnyValue{ Val: 1, }, - err: errs.NewErrInvalidType("[]byte", reflect.TypeOf(111).String()), + err: errs.NewErrInvalidType("[]byte", 1), }, } for _, tt := range tests { @@ -912,7 +911,7 @@ func TestAnyValue_Bool(t *testing.T) { val: AnyValue{ Val: 1, }, - err: errs.NewErrInvalidType("bool", reflect.TypeOf(1).String()), + err: errs.NewErrInvalidType("bool", 1), }, } for _, tt := range tests { @@ -1169,7 +1168,7 @@ func TestAnyValue_AsInt(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("int", "[]int"), + err: errs.NewErrInvalidType("int", []int{1}), }, { name: "value exists error case:", @@ -1216,7 +1215,7 @@ func TestAnyValue_AsInt8(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("int8", "[]int"), + err: errs.NewErrInvalidType("int8", []int{1}), }, { name: "value exists error case:", @@ -1263,7 +1262,7 @@ func TestAnyValue_AsInt16(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("int16", "[]int"), + err: errs.NewErrInvalidType("int16", []int{1}), }, { name: "value exists error case:", @@ -1310,7 +1309,7 @@ func TestAnyValue_AsInt32(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("int32", "[]int"), + err: errs.NewErrInvalidType("int32", []int{1}), }, { name: "value exists error case:", @@ -1357,7 +1356,7 @@ func TestAnyValue_AsInt64(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("int64", "[]int"), + err: errs.NewErrInvalidType("int64", []int{1}), }, { name: "value exists error case:", @@ -1404,7 +1403,7 @@ func TestAnyValue_AsUint(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("uint", "[]int"), + err: errs.NewErrInvalidType("uint", []int{1}), }, { name: "value exists error case:", @@ -1451,7 +1450,7 @@ func TestAnyValue_AsUint8(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("uint8", "[]int"), + err: errs.NewErrInvalidType("uint8", []int{1}), }, { name: "value exists error case:", @@ -1498,7 +1497,7 @@ func TestAnyValue_AsUint16(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("uint16", "[]int"), + err: errs.NewErrInvalidType("uint16", []int{1}), }, { name: "value exists error case:", @@ -1545,7 +1544,7 @@ func TestAnyValue_AsUint32(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("uint32", "[]int"), + err: errs.NewErrInvalidType("uint32", []int{1}), }, { name: "value exists error case:", @@ -1592,7 +1591,7 @@ func TestAnyValue_AsUint64(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("uint64", "[]int"), + err: errs.NewErrInvalidType("uint64", []int{1}), }, { name: "value exists error case:", @@ -1639,7 +1638,7 @@ func TestAnyValue_AsFloat32(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("float32", "[]int"), + err: errs.NewErrInvalidType("float32", []int{1}), }, { name: "value exists error case:", @@ -1685,7 +1684,7 @@ func TestAnyValue_AsFloat64(t *testing.T) { val: AnyValue{ Val: []int{1}, }, - err: errs.NewErrInvalidType("float64", "[]int"), + err: errs.NewErrInvalidType("float64", []int{1}), }, { name: "value exists error case:", @@ -1732,7 +1731,7 @@ func TestAnyValue_AsBytes(t *testing.T) { Val: []int{1}, }, want: []byte{}, - err: errs.NewErrInvalidType("[]byte", "[]int"), + err: errs.NewErrInvalidType("[]byte", []int{1}), }, { name: "value exists error case:", @@ -1800,14 +1799,14 @@ func TestAnyValue_AsString(t *testing.T) { val: AnyValue{ Val: []string{"h", "e", "llo"}, }, - err: errs.NewErrInvalidType("[]byte", "[]string"), + err: errs.NewErrInvalidType("[]byte", []string{"h", "e", "llo"}), }, { name: "type conversion failed by int", val: AnyValue{ Val: []int{1, 2, 3, 4, 5}, }, - err: errs.NewErrInvalidType("[]byte", "[]int"), + err: errs.NewErrInvalidType("[]byte", []int{1, 2, 3, 4, 5}), }, { name: "unsupported type case:", @@ -1829,3 +1828,45 @@ func TestAnyValue_AsString(t *testing.T) { }) } } + +func TestAnyValue_JSONScan(t *testing.T) { + testCases := []struct { + name string + + av AnyValue + + wantUser User + wantErr error + }{ + { + name: "OK", + av: AnyValue{ + Val: `{"name": "Tom"}`, + }, + wantUser: User{ + Name: "Tom", + }, + }, + + { + name: "error", + av: AnyValue{ + Err: errors.New("mock error"), + }, + wantErr: errors.New("mock error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var u User + err := tc.av.JSONScan(&u) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantUser, u) + }) + } +} + +type User struct { + Name string `json:"name"` +}