Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

export slow queries #792

Merged
merged 16 commits into from
Nov 24, 2020
85 changes: 60 additions & 25 deletions pkg/apiserver/slowquery/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,48 +106,78 @@ type GetListRequest struct {
Fields string `json:"fields" form:"fields"` // example: "Query,Digest"
}

func getProjectionsByFields(jsonFields ...string) ([]string, error) {
fields := make(map[string]*reflect.StructField)
t := reflect.TypeOf(SlowQuery{})
fieldsNum := t.NumField()
for i := 0; i < fieldsNum; i++ {
field := t.Field(i)
fields[strings.ToLower(field.Tag.Get("json"))] = &field
var cachedProjectionsMap map[string]string

func getProjectionsMap() map[string]string {
if cachedProjectionsMap == nil {
t := reflect.TypeOf(SlowQuery{})
fieldsNum := t.NumField()
ret := map[string]string{}
for i := 0; i < fieldsNum; i++ {
field := t.Field(i)
// ignore to check error because the field is defined by ourself
// we can confirm that it has "gorm" tag and fixed structure
s, _ := field.Tag.Lookup("gorm")
jsonField := strings.ToLower(field.Tag.Get("json"))
sourceField := strings.Split(s, ":")[1]
if proj, ok := field.Tag.Lookup("proj"); ok {
ret[jsonField] = fmt.Sprintf("%s AS %s", proj, sourceField)
} else {
ret[jsonField] = sourceField
}
}
cachedProjectionsMap = ret
}
return cachedProjectionsMap
}

func getProjectionsByFields(jsonFields ...string) ([]string, error) {
projMap := getProjectionsMap()
ret := make([]string, 0, len(jsonFields))
for _, fieldName := range jsonFields {
field, ok := fields[strings.ToLower(fieldName)]
field, ok := projMap[strings.ToLower(fieldName)]
if !ok {
return nil, fmt.Errorf("unknown field %s", fieldName)
}
// ignore to check error because the field is defined by ourself
// we can confirm that it has "gorm" tag and fixed structure
s, _ := field.Tag.Lookup("gorm")
sourceField := strings.Split(s, ":")[1]
if proj, ok := field.Tag.Lookup("proj"); ok {
ret = append(ret, fmt.Sprintf("%s AS %s", proj, sourceField))
} else {
ret = append(ret, sourceField)
}
ret = append(ret, field)
}
return ret, nil
}

var cachedAllProjections []string

func getAllProjections() []string {
if cachedAllProjections == nil {
projMap := getProjectionsMap()
ret := make([]string, 0, len(projMap))
for _, proj := range projMap {
ret = append(ret, proj)
}
cachedAllProjections = ret
}
return cachedAllProjections
}

type GetDetailRequest struct {
Digest string `json:"digest" form:"digest"`
Timestamp float64 `json:"timestamp" form:"timestamp"`
ConnectID int64 `json:"connect_id" form:"connect_id"`
}

func QuerySlowLogList(db *gorm.DB, req *GetListRequest) ([]SlowQuery, error) {
sqlFields := []string{"digest", "connection_id", "timestamp"}
if strings.TrimSpace(req.Fields) != "" {
sqlFields = append(sqlFields, strings.Split(req.Fields, ",")...)
sqlFields = funk.UniqString(sqlFields)
}
projections, err := getProjectionsByFields(sqlFields...)
if err != nil {
return nil, err
var projections []string
var err error
reqFields := strings.Split(req.Fields, ",")
if len(reqFields) == 1 && reqFields[0] == "*" {
projections = getAllProjections()
} else {
projections, err = getProjectionsByFields(
funk.UniqString(
append([]string{"digest", "connection_id", "timestamp"}, reqFields...),
)...)
if err != nil {
return nil, err
}
}

tx := db.
Expand All @@ -174,6 +204,11 @@ func QuerySlowLogList(db *gorm.DB, req *GetListRequest) ([]SlowQuery, error) {
tx = tx.Where("DB IN (?)", req.DB)
}

// more robust
if req.OrderBy == "" {
req.OrderBy = "timestamp"
}

order, err := getProjectionsByFields(req.OrderBy)
if err != nil {
return nil, err
Expand Down
81 changes: 77 additions & 4 deletions pkg/apiserver/slowquery/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
package slowquery

import (
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
"go.uber.org/fx"
Expand All @@ -39,10 +43,17 @@ func NewService(p ServiceParams) *Service {

func Register(r *gin.RouterGroup, auth *user.AuthService, s *Service) {
endpoint := r.Group("/slow_query")
endpoint.Use(auth.MWAuthRequired())
endpoint.Use(utils.MWConnectTiDB(s.params.TiDBClient))
endpoint.GET("/list", s.listHandler)
endpoint.GET("/detail", s.detailhandler)
{
endpoint.GET("/download", s.downloadHandler)
endpoint.Use(auth.MWAuthRequired())
endpoint.Use(utils.MWConnectTiDB(s.params.TiDBClient))
{
endpoint.GET("/list", s.listHandler)
endpoint.GET("/detail", s.detailhandler)

endpoint.POST("/download/token", s.downloadTokenHandler)
}
}
}

// @Summary List all slow queries
Expand Down Expand Up @@ -88,3 +99,65 @@ func (s *Service) detailhandler(c *gin.Context) {
}
c.JSON(http.StatusOK, *result)
}

// @Router /slow_query/download/token [post]
// @Summary Generate a download token for exported slow query statements
// @Produce plain
// @Param request body GetListRequest true "Request body"
// @Success 200 {string} string "xxx"
// @Security JwtAuth
// @Failure 401 {object} utils.APIError "Unauthorized failure"
func (s *Service) downloadTokenHandler(c *gin.Context) {
var req GetListRequest
if err := c.ShouldBindJSON(&req); err != nil {
utils.MakeInvalidRequestErrorFromError(c, err)
return
}
db := utils.GetTiDBConnection(c)
fields := []string{}
if strings.TrimSpace(req.Fields) != "" {
fields = strings.Split(req.Fields, ",")
}
list, err := QuerySlowLogList(db, &req)
if err != nil {
_ = c.Error(err)
return
}
if len(list) == 0 {
utils.MakeInvalidRequestErrorFromError(c, errors.New("no data to export"))
unbyte marked this conversation as resolved.
Show resolved Hide resolved
return
}

// interface{} tricky
rawData := make([]interface{}, len(list))
for i, v := range list {
rawData[i] = v
}

// convert data
csvData := utils.GenerateCSVFromRaw(rawData, fields, []string{})

// generate temp file that persist encrypted data
timeLayout := "0102150405"
currentTime := time.Now().Format(timeLayout)
token, err := utils.ExportCSV(csvData,
fmt.Sprintf("slowquery_%s_*.csv", currentTime),
unbyte marked this conversation as resolved.
Show resolved Hide resolved
"slowquery/download")

if err != nil {
_ = c.Error(err)
return
}
c.String(http.StatusOK, token)
}

// @Router /slow_query/download [get]
// @Summary Download slow query statements
// @Produce text/csv
// @Param token query string true "download token"
// @Failure 400 {object} utils.APIError
// @Failure 401 {object} utils.APIError "Unauthorized failure"
func (s *Service) downloadHandler(c *gin.Context) {
token := c.Query("token")
utils.DownloadByToken(token, "slowquery/download", c)
}
52 changes: 31 additions & 21 deletions pkg/apiserver/statement/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,40 +103,50 @@ type Model struct {
RelatedSchemas string `json:"related_schemas"`
}

func getAggrFields(sqlFields ...string) []string {
fields := make(map[string]*reflect.StructField)
t := reflect.TypeOf(Model{})
fieldsNum := t.NumField()
for i := 0; i < fieldsNum; i++ {
field := t.Field(i)
fields[strings.ToLower(field.Tag.Get("json"))] = &field
var cachedAggrMap map[string]string // jsonFieldName => aggr

func getAggrMap() map[string]string {
if cachedAggrMap == nil {
t := reflect.TypeOf(Model{})
fieldsNum := t.NumField()
ret := map[string]string{}
for i := 0; i < fieldsNum; i++ {
field := t.Field(i)
jsonField := strings.ToLower(field.Tag.Get("json"))
if agg, ok := field.Tag.Lookup("agg"); ok {
ret[jsonField] = fmt.Sprintf("%s AS %s", agg, gorm.ToColumnName(field.Name))
}
}
cachedAggrMap = ret
}
return cachedAggrMap
}

func getAggrFields(sqlFields ...string) []string {
aggrMap := getAggrMap()
ret := make([]string, 0, len(sqlFields))
for _, fieldName := range sqlFields {
if field, ok := fields[strings.ToLower(fieldName)]; ok {
if agg, ok := field.Tag.Lookup("agg"); ok {
ret = append(ret, fmt.Sprintf("%s AS %s", agg, gorm.ToColumnName(field.Name)))
} else {
panic(fmt.Sprintf("field %s cannot be aggregated", fieldName))
}
if aggr, ok := aggrMap[strings.ToLower(fieldName)]; ok {
ret = append(ret, aggr)
} else {
panic(fmt.Sprintf("unknown aggregation field %s", fieldName))
}
}
return ret
}

var cachedAllAggrFields []string

func getAllAggrFields() []string {
t := reflect.TypeOf(Model{})
fieldsNum := t.NumField()
ret := make([]string, 0, fieldsNum)
for i := 0; i < fieldsNum; i++ {
field := t.Field(i)
if agg, ok := field.Tag.Lookup("agg"); ok {
ret = append(ret, fmt.Sprintf("%s AS %s", agg, gorm.ToColumnName(field.Name)))
if cachedAllAggrFields == nil {
aggrMap := getAggrMap()
ret := make([]string, 0, len(aggrMap))
for _, aggr := range aggrMap {
ret = append(ret, aggr)
}
cachedAllAggrFields = ret
}
return ret
return cachedAllAggrFields
}

// tableNames example: "d1.a1,d2.a2,d1.a1,d3.a3"
Expand Down
Loading