Skip to content

Commit

Permalink
*: fix some audit log error (#26767)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored Aug 5, 2021
1 parent c08de09 commit 072cf27
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 5 deletions.
8 changes: 7 additions & 1 deletion bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,13 @@ func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("EXPLAIN FORMAT='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = origVals
if rs != nil {
defer terror.Call(rs.Close)
defer func() {
// Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'.
origin := sctx.GetSessionVars().InRestrictedSQL
sctx.GetSessionVars().InRestrictedSQL = true
terror.Call(rs.Close)
sctx.GetSessionVars().InRestrictedSQL = origin
}()
}
if err != nil {
return "", err
Expand Down
1 change: 1 addition & 0 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ func (a *ExecStmt) logAudit() {
if sessVars.InRestrictedSQL {
return
}

err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
audit := plugin.DeclareAuditManifest(p.Manifest)
if audit.OnGeneralEvent != nil {
Expand Down
7 changes: 7 additions & 0 deletions executor/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ func (e *SQLBindExec) dropSQLBind() error {
}

func (e *SQLBindExec) createSQLBind() error {
// For audit log, SQLBindExec execute "explain" statement internally, save and recover stmtctx
// is necessary to avoid 'create binding' been recorded as 'explain'.
saveStmtCtx := e.ctx.GetSessionVars().StmtCtx
defer func() {
e.ctx.GetSessionVars().StmtCtx = saveStmtCtx
}()

bindInfo := bindinfo.Binding{
BindSQL: e.bindSQL,
Charset: e.charset,
Expand Down
9 changes: 9 additions & 0 deletions executor/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,9 @@ func GetStmtLabel(stmtNode ast.StmtNode) string {
case *ast.DropIndexStmt:
return "DropIndex"
case *ast.DropTableStmt:
if x.IsView {
return "DropView"
}
return "DropTable"
case *ast.ExplainStmt:
return "Explain"
Expand Down Expand Up @@ -373,6 +376,12 @@ func GetStmtLabel(stmtNode ast.StmtNode) string {
return "CreateBinding"
case *ast.IndexAdviseStmt:
return "IndexAdvise"
case *ast.DropBindingStmt:
return "DropBinding"
case *ast.TraceStmt:
return "Trace"
case *ast.ShutdownStmt:
return "Shutdown"
}
return "other"
}
14 changes: 11 additions & 3 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ type PrepareExec struct {
ID uint32
ParamCount int
Fields []*ast.ResultField

// If it's generated from executing "prepare stmt from '...'", the process is parse -> plan -> executor
// If it's generated from the prepare protocol, the process is session.PrepareStmt -> NewPrepareExec
// They both generate a PrepareExec struct, but the second case needs to reset the statement context while the first already do that.
needReset bool
}

// NewPrepareExec creates a new PrepareExec.
Expand All @@ -96,6 +101,7 @@ func NewPrepareExec(ctx sessionctx.Context, sqlTxt string) *PrepareExec {
return &PrepareExec{
baseExecutor: base,
sqlText: sqlTxt,
needReset: true,
}
}

Expand Down Expand Up @@ -135,9 +141,11 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
stmt := stmts[0]

err = ResetContextOfStmt(e.ctx, stmt)
if err != nil {
return err
if e.needReset {
err = ResetContextOfStmt(e.ctx, stmt)
if err != nil {
return err
}
}

var extractor paramMarkerExtractor
Expand Down
14 changes: 14 additions & 0 deletions executor/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ func (e *TraceExec) Next(ctx context.Context, req *chunk.Chunk) error {
return nil
}

// For audit log plugin to set the correct statement.
stmtCtx := e.ctx.GetSessionVars().StmtCtx
defer func() {
e.ctx.GetSessionVars().StmtCtx = stmtCtx
}()

switch e.format {
case core.TraceFormatLog:
return e.nextTraceLog(ctx, se, req)
Expand Down Expand Up @@ -130,6 +136,14 @@ func (e *TraceExec) nextRowJSON(ctx context.Context, se sqlexec.SQLExecutor, req
}

func (e *TraceExec) executeChild(ctx context.Context, se sqlexec.SQLExecutor) {
// For audit log plugin to log the statement correctly.
// Should be logged as 'explain ...', instead of the executed SQL.
vars := e.ctx.GetSessionVars()
origin := vars.InRestrictedSQL
vars.InRestrictedSQL = true
defer func() {
vars.InRestrictedSQL = origin
}()
rs, err := se.ExecuteStmt(ctx, e.stmtNode)
if err != nil {
var errCode uint16
Expand Down
162 changes: 162 additions & 0 deletions plugin/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package plugin_test

import (
"bytes"
"context"
"fmt"
"strconv"
"testing"

"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/testutils"
)

type testAuditLogSuite struct {
cluster testutils.Cluster
store kv.Storage
dom *domain.Domain

bytes.Buffer
}

func (s *testAuditLogSuite) setup(t *testing.T) {
pluginName := "test_audit_log"
pluginVersion := uint16(1)
pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion))

config.UpdateGlobal(func(conf *config.Config) {
conf.Plugin.Load = pluginSign
})

// setup load test hook.
loadOne := func(p *plugin.Plugin, dir string, pluginID plugin.ID) (manifest func() *plugin.Manifest, err error) {
return func() *plugin.Manifest {
m := &plugin.AuditManifest{
Manifest: plugin.Manifest{
Kind: plugin.Audit,
Name: pluginName,
Version: pluginVersion,
OnInit: OnInit,
OnShutdown: OnShutdown,
Validate: Validate,
},
OnGeneralEvent: s.OnGeneralEvent,
OnConnectionEvent: OnConnectionEvent,
}
return plugin.ExportManifest(m)
}, nil
}
plugin.SetTestHook(loadOne)

store, err := mockstore.NewMockStore(
mockstore.WithClusterInspector(func(c testutils.Cluster) {
mockstore.BootstrapWithSingleStore(c)
s.cluster = c
}),
)
require.NoError(t, err)
s.store = store
session.SetSchemaLease(0)
session.DisableStats4Test()

d, err := session.BootstrapSession(s.store)
require.NoError(t, err)
d.SetStatsUpdating(true)
s.dom = d
}

func (s *testAuditLogSuite) teardown() {
s.dom.Close()
s.store.Close()
}

func TestAuditLog(t *testing.T) {
var s testAuditLogSuite
s.setup(t)
defer s.teardown()

var buf1 bytes.Buffer
tk := testkit.NewAsyncTestKit(t, s.store)
ctx := tk.OpenSession(context.Background(), "test")
buf1.WriteString("Use use `test`\n") // Workaround for the testing framework.

tk.MustExec(ctx, "use test")
buf1.WriteString("Use use `test`\n")

tk.MustExec(ctx, "create table t (id int primary key, a int, b int unique)")
buf1.WriteString("CreateTable create table `t` ( `id` int primary key , `a` int , `b` int unique )\n")

tk.MustExec(ctx, "create view v1 as select * from t where id > 2")
buf1.WriteString("CreateView create view `v1` as select * from `t` where `id` > ?\n")

tk.MustExec(ctx, "drop view v1")
buf1.WriteString("DropView drop view `v1`\n")

tk.MustExec(ctx, "create session binding for select * from t where b = 123 using select * from t ignore index(b) where b = 123")
buf1.WriteString("CreateBinding create session binding for select * from `t` where `b` = ? using select * from `t` where `b` = ?\n")

tk.MustExec(ctx, "prepare mystmt from 'select ? as num from DUAL'")
buf1.WriteString("Prepare prepare `mystmt` from ?\n")

tk.MustExec(ctx, "set @number = 5")
buf1.WriteString("Set set @number = ?\n")

tk.MustExec(ctx, "execute mystmt using @number")
buf1.WriteString("Select select ? as `num` from dual\n")

tk.MustQuery(ctx, "trace format = 'row' select * from t")
buf1.WriteString("Trace trace format = ? select * from `t`\n")

tk.MustExec(ctx, "shutdown")
buf1.WriteString("Shutdown shutdown\n")

require.Equal(t, buf1.String(), s.Buffer.String())
}

func Validate(ctx context.Context, m *plugin.Manifest) error {
return nil
}

// OnInit implements TiDB plugin's OnInit SPI.
func OnInit(ctx context.Context, manifest *plugin.Manifest) error {
return nil
}

// OnShutdown implements TiDB plugin's OnShutdown SPI.
func OnShutdown(ctx context.Context, manifest *plugin.Manifest) error {
return nil
}

// OnGeneralEvent implements TiDB Audit plugin's OnGeneralEvent SPI.
func (s *testAuditLogSuite) OnGeneralEvent(ctx context.Context, sctx *variable.SessionVars, event plugin.GeneralEvent, cmd string) {
if sctx != nil {
normalized, _ := sctx.StmtCtx.SQLDigest()
fmt.Fprintln(&s.Buffer, sctx.StmtCtx.StmtType, normalized)
}
}

// OnConnectionEvent implements TiDB Audit plugin's OnConnectionEvent SPI.
func OnConnectionEvent(ctx context.Context, event plugin.ConnectionEvent, info *variable.ConnectionInfo) error {
return nil
}
9 changes: 8 additions & 1 deletion plugin/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,12 @@ import (

func TestMain(m *testing.M) {
testbridge.WorkaroundGoCheckFlags()
goleak.VerifyTestMain(m)

opts := []goleak.Option{
goleak.IgnoreTopFunction("go.etcd.io/etcd/pkg/logutil.(*MergeLogger).outputLoop"),
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
goleak.IgnoreTopFunction("time.Sleep"),
}

goleak.VerifyTestMain(m, opts...)
}

0 comments on commit 072cf27

Please sign in to comment.