Skip to content

Commit

Permalink
feat(merger): 新增Distinct Merger (#224)
Browse files Browse the repository at this point in the history
Co-authored-by: juniaoshaonian <1633720889@qq.com>
  • Loading branch information
longyue0521 and juniaoshaonian authored Jun 14, 2024
1 parent 16c03e1 commit 88085be
Show file tree
Hide file tree
Showing 15 changed files with 3,829 additions and 899 deletions.
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
- [doc: 修复README中不可用的贡献者指南链接](https://github.com/ecodeclub/eorm/pull/221)
- [feat(merger): 定义中立的特征表达数据、定义工厂方法根据特征数据来获取具体的merger](https://github.com/ecodeclub/eorm/pull/222)
- [refactor(merger): 重构AVG函数实现,重构所有rows.Rows实现的ConlumnType方法并添加测试](https://github.com/ecodeclub/eorm/pull/223)
- [feat(merger): 新增Distinct Merger](https://github.com/ecodeclub/eorm/pull/224)
## v0.0.1:
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)
- [Selector Definition](https://github.com/ecodeclub/eorm/pull/2)
Expand Down
91 changes: 66 additions & 25 deletions internal/merger/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ import (
"context"
"errors"
"fmt"
"log"
"strings"

"github.com/ecodeclub/ekit/slice"
"github.com/ecodeclub/eorm/internal/merger"
"github.com/ecodeclub/eorm/internal/merger/internal/aggregatemerger"
"github.com/ecodeclub/eorm/internal/merger/internal/aggregatemerger/aggregator"
"github.com/ecodeclub/eorm/internal/merger/internal/batchmerger"
"github.com/ecodeclub/eorm/internal/merger/internal/distinctmerger"
"github.com/ecodeclub/eorm/internal/merger/internal/groupbymerger"
"github.com/ecodeclub/eorm/internal/merger/internal/pagedmerger"
"github.com/ecodeclub/eorm/internal/merger/internal/sortmerger"
Expand All @@ -39,6 +39,7 @@ var (
ErrColumnNotFoundInSelectList = errors.New("factory: Select列表中未找到列")
ErrInvalidLimit = errors.New("factory: Limit小于1")
ErrInvalidOffset = errors.New("factory: Offset不等于0")
ErrInvalidFeatures = errors.New("factory: Features非法")
)

type (
Expand All @@ -62,6 +63,10 @@ type (

func (q QuerySpec) Validate() error {

if err := q.validateFeatures(); err != nil {
return err
}

if err := q.validateSelect(); err != nil {
return err
}
Expand All @@ -70,6 +75,10 @@ func (q QuerySpec) Validate() error {
return err
}

if err := q.validateDistinct(); err != nil {
return err
}

if err := q.validateOrderBy(); err != nil {
return err
}
Expand All @@ -81,6 +90,24 @@ func (q QuerySpec) Validate() error {
return nil
}

func (q QuerySpec) validateFeatures() error {
for i, v := range q.Features {
if i == 0 {
continue
}
if v < q.Features[i-1] {
return fmt.Errorf("%w: 顺序错误", ErrInvalidFeatures)
}
}
if slice.Contains(q.Features, query.AggregateFunc) && slice.Contains(q.Features, query.GroupBy) {
return fmt.Errorf("%w: 聚合特征与GroupBy不该同时出现", ErrInvalidFeatures)
}
if slice.Contains(q.Features, query.GroupBy) && slice.Contains(q.Features, query.Distinct) {
return fmt.Errorf("%w: GroupBy与DISTINCT不该同时出现", ErrInvalidFeatures)
}
return nil
}

func (q QuerySpec) validateSelect() error {
if len(q.Select) == 0 {
return fmt.Errorf("%w: select", ErrEmptyColumnList)
Expand All @@ -105,7 +132,7 @@ func (q QuerySpec) validateGroupBy() error {
return fmt.Errorf("%w: groupby %v", ErrInvalidColumnInfo, c.Name)
}
// 清除ASC
c.Order = merger.DESC
c.Order = merger.OrderDESC
if !slice.Contains(q.Select, c) {
return fmt.Errorf("%w: groupby %v", ErrColumnNotFoundInSelectList, c.Name)
}
Expand All @@ -121,6 +148,20 @@ func (q QuerySpec) validateGroupBy() error {
return nil
}

func (q QuerySpec) validateDistinct() error {
if !slice.Contains(q.Features, query.Distinct) {
return nil
}
// 程序走到这q.Select的长度至少为1
for _, c := range q.Select {
// case2,3
if !c.Distinct || !c.Validate() {
return fmt.Errorf("%w: distinct %v", ErrInvalidColumnInfo, c.Name)
}
}
return nil
}

func (q QuerySpec) validateOrderBy() error {
if !slice.Contains(q.Features, query.OrderBy) {
return nil
Expand All @@ -133,9 +174,10 @@ func (q QuerySpec) validateOrderBy() error {
if !c.Validate() {
return fmt.Errorf("%w: orderby %v", ErrInvalidColumnInfo, c.Name)
}
// 清除ASC
c.Order = merger.DESC
if !slice.Contains(q.Select, c) {
_, ok := slice.Find(q.Select, func(src merger.ColumnInfo) bool {
return src.Index == c.Index && src.SelectName() == c.SelectName()
})
if !ok {
return fmt.Errorf("%w: orderby %v", ErrColumnNotFoundInSelectList, c.Name)
}
}
Expand All @@ -159,7 +201,6 @@ func (q QuerySpec) validateLimit() error {

func newAggregateMerger(origin, target QuerySpec) (merger.Merger, error) {
aggregators := getAggregators(origin, target)
log.Printf("aggregators = %#v\n", aggregators)
// TODO: 当aggs为空时, 报不相关的错 merger: scan之前需要调用Next
return aggregatemerger.NewMerger(aggregators...), nil
}
Expand All @@ -171,51 +212,53 @@ func getAggregators(_, target QuerySpec) []aggregator.Aggregator {
switch strings.ToUpper(c.AggregateFunc) {
case "MIN":
aggregators = append(aggregators, aggregator.NewMin(c))
log.Printf("min index = %d\n", c.Index)
case "MAX":
aggregators = append(aggregators, aggregator.NewMax(c))
log.Printf("max index = %d\n", c.Index)
case "AVG":
aggregators = append(aggregators, aggregator.NewAVG(c, target.Select[i+1], target.Select[i+2]))
i += 2
log.Printf("avg index = %d\n", c.Index)
case "SUM":
aggregators = append(aggregators, aggregator.NewSum(c))
log.Printf("sum index = %d\n", c.Index)
case "COUNT":
aggregators = append(aggregators, aggregator.NewCount(c))
log.Printf("count index = %d\n", c.Index)
}
}
return aggregators
}

func newGroupByMergerWithoutHaving(origin, target QuerySpec) (merger.Merger, error) {
aggregators := getAggregators(origin, target)
log.Printf("groupby aggregators = %#v\n", aggregators)
return groupbymerger.NewAggregatorMerger(aggregators, target.GroupBy), nil
}

func newDistinctMerger(_, target QuerySpec) (merger.Merger, error) {
var sortColumns merger.SortColumns
if len(target.OrderBy) != 0 {
s, err := merger.NewSortColumns(target.OrderBy...)
if err != nil {
return nil, err
}
sortColumns = s
}
return distinctmerger.NewMerger(target.Select, sortColumns)
}

func newOrderByMerger(origin, target QuerySpec) (merger.Merger, error) {
var columns []sortmerger.SortColumn
var columns []merger.ColumnInfo
for i := 0; i < len(target.OrderBy); i++ {
c := target.OrderBy[i]
if i < len(origin.OrderBy) && strings.ToUpper(origin.OrderBy[i].AggregateFunc) == "AVG" {
s := sortmerger.NewSortColumn(origin.OrderBy[i].SelectName(), sortmerger.Order(origin.OrderBy[i].Order))
columns = append(columns, s)
columns = append(columns, origin.OrderBy[i])
i++
continue
}
s := sortmerger.NewSortColumn(c.SelectName(), sortmerger.Order(c.Order))
columns = append(columns, s)
columns = append(columns, c)
}

var isPreScanAll bool
if slice.Contains(target.Features, query.GroupBy) {
isPreScanAll = true
}

log.Printf("sortColumns = %#v\n", columns)
return sortmerger.NewMerger(isPreScanAll, columns...)
}

Expand All @@ -228,12 +271,13 @@ func New(origin, target QuerySpec) (merger.Merger, error) {
var mp = map[query.Feature]newMergerFunc{
query.AggregateFunc: newAggregateMerger,
query.GroupBy: newGroupByMergerWithoutHaving,
query.Distinct: newDistinctMerger,
query.OrderBy: newOrderByMerger,
}
var mergers []merger.Merger
for _, feature := range target.Features {
switch feature {
case query.AggregateFunc, query.GroupBy, query.OrderBy:
case query.AggregateFunc, query.GroupBy, query.Distinct, query.OrderBy:
m, err := mp[feature](origin, target)
if err != nil {
return nil, err
Expand All @@ -252,12 +296,13 @@ func New(origin, target QuerySpec) (merger.Merger, error) {
return nil, err
}
mergers = append(mergers, m)
default:
return nil, fmt.Errorf("%w: feature: %d", ErrInvalidFeatures, feature)
}
}
if len(mergers) == 0 {
mergers = append(mergers, batchmerger.NewMerger())
}
log.Printf("mergers = %#v\n", mergers)
return &MergerPipeline{mergers: mergers}, nil
}

Expand All @@ -273,15 +318,11 @@ func (m *MergerPipeline) Merge(ctx context.Context, results []rows.Rows) (rows.R
if len(m.mergers) == 1 {
return r, nil
}
columns, _ := r.Columns()
log.Printf("pipline merge[0] columns = %#v\n", columns)
for _, mg := range m.mergers[1:] {
r, err = mg.Merge(ctx, []rows.Rows{r})
if err != nil {
return nil, err
}
c, _ := r.Columns()
log.Printf("pipline merge[1:] columns = %#v\n", c)
}
return r, nil
}
Expand Down
Loading

0 comments on commit 88085be

Please sign in to comment.