Skip to content

Commit

Permalink
Merge pull request #790 from ClickHouse/issue_741
Browse files Browse the repository at this point in the history
Enforce sort order of columns as specified in INSERT
  • Loading branch information
gingerwizard authored Oct 20, 2022
2 parents 7445e6c + a9f26c8 commit 99cf774
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 12 deletions.
14 changes: 14 additions & 0 deletions conn_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ import (
)

var splitInsertRe = regexp.MustCompile(`(?i)\sVALUES\s*\(`)
var columnMatch = regexp.MustCompile(`.*\((?P<Columns>.+)\)$`)

func (c *connect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) {
query = splitInsertRe.Split(query, -1)[0]
colMatch := columnMatch.FindStringSubmatch(query)
var columns []string
if len(colMatch) == 2 {
columns = strings.Split(colMatch[1], ",")
for i := range columns {
columns[i] = strings.TrimSpace(columns[i])
}
}
if !strings.HasSuffix(strings.TrimSpace(strings.ToUpper(query)), "VALUES") {
query += " VALUES"
}
Expand All @@ -54,6 +63,10 @@ func (c *connect) prepareBatch(ctx context.Context, query string, release func(*
release(c, err)
return nil, err
}
// resort batch to specified columns
if err = block.SortColumns(columns); err != nil {
return nil, err
}
return &batch{
ctx: ctx,
conn: c,
Expand Down Expand Up @@ -90,6 +103,7 @@ func (b *batch) Append(v ...interface{}) error {
if b.sent {
return ErrBatchAlreadySent
}
//
if err := b.block.Append(v...); err != nil {
b.release(err)
return err
Expand Down
48 changes: 37 additions & 11 deletions conn_http_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,27 @@ import (
"io"
"io/ioutil"
"regexp"
"strings"
)

var splitHttpInsertRe = regexp.MustCompile("(?i)^INSERT INTO\\s+`?([\\w.]+)`?")
// \x60 represents a backtick
var httpInsertRe = regexp.MustCompile(`(?i)^INSERT INTO\s+\x60?([\w.^\(]+)\x60?\s*(\([^\)]*\))?`)

// release is ignored, because http used by std with empty release function
func (h *httpConnect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) {
index := splitHttpInsertRe.FindStringSubmatchIndex(query)

if len(index) < 3 {
matches := httpInsertRe.FindStringSubmatch(query)
if len(matches) < 3 {
return nil, errors.New("cannot get table name from query")
}

tableName := query[index[2]:index[3]]
tableName := matches[1]
var rColumns []string
if matches[2] != "" {
colMatch := strings.TrimSuffix(strings.TrimPrefix(matches[2], "("), ")")
rColumns = strings.Split(colMatch, ",")
for i := range rColumns {
rColumns[i] = strings.TrimSpace(rColumns[i])
}
}
query = "INSERT INTO " + tableName + " FORMAT Native"
queryTableSchema := "DESCRIBE TABLE " + tableName
r, err := h.query(ctx, release, queryTableSchema)
Expand All @@ -50,21 +58,39 @@ func (h *httpConnect) prepareBatch(ctx context.Context, query string, release fu
block := &proto.Block{}

// get Table columns and types
columns := make(map[string]string)
var colNames []string
for r.Next() {
var (
colName string
colType string
ignore string
)

err = r.Scan(&colName, &colType, &ignore, &ignore, &ignore, &ignore, &ignore)
if err != nil {
if err = r.Scan(&colName, &colType, &ignore, &ignore, &ignore, &ignore, &ignore); err != nil {
return nil, err
}
colNames = append(colNames, colName)
columns[colName] = colType
}

err = block.AddColumn(colName, column.Type(colType))
if err != nil {
return nil, err
switch len(rColumns) {
case 0:
for _, colName := range colNames {
if err = block.AddColumn(colName, column.Type(columns[colName])); err != nil {
return nil, err
}
}
default:
// user has requested specific columns so only include these
for _, colName := range rColumns {
if colType, ok := columns[colName]; ok {
if err = block.AddColumn(colName, column.Type(colType)); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("column %s is not present in the table %s", colName, tableName)
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions lib/proto/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package proto
import (
"errors"
"fmt"
"sort"
"time"

"github.com/ClickHouse/ch-go/proto"
Expand Down Expand Up @@ -73,6 +74,49 @@ func (b *Block) ColumnsNames() []string {
return b.names
}

// SortColumns reorders the block's columns, and the parallel list of column
// names, to follow the order given in columns. An empty columns slice means
// there is no preferred order and the block is left untouched. An error is
// returned when columns has a different length from the block, or when a
// column of the block does not appear in columns.
func (b *Block) SortColumns(columns []string) error {
	if len(columns) == 0 {
		return nil
	}
	if want, got := len(b.Columns), len(columns); want != got {
		return fmt.Errorf("requested column order is incorrect length to sort block - expected %d, got %d", want, got)
	}
	// Position of every requested name; a repeated name keeps its last position.
	rank := make(map[string]int, len(columns))
	for pos, name := range columns {
		rank[name] = pos
	}
	// Every block column must have a requested position before we sort.
	var missing []string
	for _, name := range b.names {
		if _, ok := rank[name]; !ok {
			missing = append(missing, name)
		}
	}
	if len(missing) > 0 {
		return fmt.Errorf("block cannot be sorted - missing columns in requested order: %v", missing)
	}
	// Columns and names are kept as two slices, so both are sorted by the same rank.
	sort.Slice(b.Columns, func(i, j int) bool {
		return rank[b.Columns[i].Name()] < rank[b.Columns[j].Name()]
	})
	sort.Slice(b.names, func(i, j int) bool {
		return rank[b.names[i]] < rank[b.names[j]]
	})
	return nil
}

// difference returns the elements of a that do not occur in b, preserving
// their order (and any repeats) from a. The result is nil when every element
// of a is present in b.
func difference(a, b []string) []string {
	exclude := make(map[string]struct{}, len(b))
	for _, name := range b {
		exclude[name] = struct{}{}
	}
	var remaining []string
	for _, name := range a {
		if _, skip := exclude[name]; skip {
			continue
		}
		remaining = append(remaining, name)
	}
	return remaining
}

func (b *Block) Encode(buffer *proto.Buffer, revision uint64) error {
if revision > 0 {
encodeBlockInfo(buffer)
Expand Down
187 changes: 187 additions & 0 deletions tests/issues/741_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package issues

import (
"context"
"fmt"
"github.com/ClickHouse/clickhouse-go/v2"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
clickhouse_std_tests "github.com/ClickHouse/clickhouse-go/v2/tests/std"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"math/rand"
"net"
"strconv"
"strings"
"testing"
"time"
)

// TestIssue741 checks, over both the native and HTTP protocols, that a
// prepared INSERT whose column list (Col2, Col1) is in a different order from
// the table definition (Col1, Col2) executes without error.
func TestIssue741(t *testing.T) {
	useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
	require.NoError(t, err)
	protocols := []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP}
	for _, protocol := range protocols {
		t.Run(fmt.Sprintf("%v Protocol", protocol), func(t *testing.T) {
			conn, err := clickhouse_std_tests.GetDSNConnection("issues", protocol, useSSL, "false")
			require.NoError(t, err)
			// Best-effort removal of a table left behind by an earlier run.
			conn.Exec("DROP TABLE IF EXISTS issue_741")
			ddl := `
CREATE TABLE issue_741 (
Col1 String,
Col2 Int64
)
Engine MergeTree() ORDER BY tuple()
`
			_, err = conn.Exec(ddl)
			require.NoError(t, err)
			defer func() {
				conn.Exec("DROP TABLE issue_741")
			}()
			stmt, err := conn.Prepare("INSERT INTO issue_741 (Col2, Col1) VALUES (? ?)")
			// Fail here rather than dereferencing a nil stmt if the prepare is rejected.
			require.NoError(t, err)
			_, err = stmt.Exec(int64(1), "1")
			require.NoError(t, err)
		})
	}
}

// TestIssue741SingleColumn checks, over both the native and HTTP protocols,
// that a prepared INSERT naming only one of the table's two columns executes
// without error.
func TestIssue741SingleColumn(t *testing.T) {
	useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
	require.NoError(t, err)
	protocols := []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP}
	for _, protocol := range protocols {
		t.Run(fmt.Sprintf("%v Protocol", protocol), func(t *testing.T) {
			conn, err := clickhouse_std_tests.GetDSNConnection("issues", protocol, useSSL, "false")
			require.NoError(t, err)
			// Best-effort removal of a table left behind by an earlier run.
			conn.Exec("DROP TABLE IF EXISTS issue_741_single")
			ddl := `
CREATE TABLE issue_741_single (
Col1 String,
Col2 Int64
)
Engine MergeTree() ORDER BY tuple()
`
			_, err = conn.Exec(ddl)
			require.NoError(t, err)
			defer func() {
				conn.Exec("DROP TABLE issue_741_single")
			}()
			stmt, err := conn.Prepare("INSERT INTO issue_741_single (Col1) VALUES (?)")
			// Fail here rather than dereferencing a nil stmt if the prepare is rejected.
			require.NoError(t, err)
			_, err = stmt.Exec("1")
			require.NoError(t, err)
		})
	}
}

// generateRandomInsert builds a test fixture for tableName: a CREATE TABLE
// statement with the sample columns in one random order, an INSERT statement
// listing the same columns in a second random order, and the sample values
// arranged to match the INSERT's column order.
func generateRandomInsert(tableName string) (string, string, []interface{}) {
	// Keys are "<name> <type>" column definitions; values are a sample for that type.
	samples := map[string]interface{}{
		"Col1 String":       "a",
		"Col2 Int64":        int64(1),
		"Col3 Int32":        int32(2),
		"Col4 Bool":         true,
		"Col5 Date32":       time.Now(),
		"Col6 IPv4":         net.ParseIP("8.8.8.8"),
		"Col7 Decimal32(5)": decimal.New(25, 0),
		"Col8 UUID":         uuid.New(),
	}
	defs := make([]string, 0, len(samples))
	for def := range samples {
		defs = append(defs, def)
	}
	swap := func(i, j int) { defs[i], defs[j] = defs[j], defs[i] }

	// first shuffle decides the column order used by the table definition
	rand.Shuffle(len(defs), swap)
	ddl := fmt.Sprintf(`
CREATE TABLE %s (
%s
)
Engine MergeTree() ORDER BY tuple()`, tableName, strings.Join(defs, ", "))

	// second shuffle decides the column order used by the insert
	rand.Shuffle(len(defs), swap)
	var (
		names        = make([]string, len(defs))
		placeholders = make([]string, len(defs))
		values       = make([]interface{}, len(defs))
	)
	for i, def := range defs {
		// the column name is the first word of its definition
		names[i] = strings.Fields(def)[0]
		placeholders[i] = "?"
		values[i] = samples[def]
	}
	insertStatement := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(names, ", "), strings.Join(placeholders, ", "))
	return ddl, insertStatement, values
}

// TestIssue741RandomOrder creates a table with randomly ordered columns and
// checks, for each protocol, that a prepared INSERT listing those columns in
// a different random order executes without error.
func TestIssue741RandomOrder(t *testing.T) {
	useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
	require.NoError(t, err)
	protocols := []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP}
	for _, protocol := range protocols {
		t.Run(fmt.Sprintf("%v Protocol", protocol), func(t *testing.T) {
			// Connect with the protocol under test; previously this was hard-coded
			// to clickhouse.Native, so the HTTP subtest never used HTTP.
			conn, err := clickhouse_std_tests.GetDSNConnection("issues", protocol, useSSL, "false")
			require.NoError(t, err)
			// Best-effort removal of a table left behind by an earlier run.
			conn.Exec("DROP TABLE IF EXISTS issue_741_random")
			defer func() {
				conn.Exec("DROP TABLE issue_741_random")
			}()
			ddl, insertStatement, values := generateRandomInsert("issue_741_random")
			_, err = conn.Exec(ddl)
			require.NoError(t, err)
			// The statement is already fully formatted; pass it through as-is.
			stmt, err := conn.Prepare(insertStatement)
			require.NoError(t, err)
			_, err = stmt.Exec(values...)
			require.NoError(t, err)
		})
	}
}

// test Append on native connection
func TestIssue741NativeAppend(t *testing.T) {
var (
conn, err = clickhouse_tests.GetConnection("issues", clickhouse.Settings{
"max_execution_time": 60,
}, nil, &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
})
)
ctx := context.Background()
require.NoError(t, err)
conn.Exec(ctx, "DROP TABLE IF EXISTS issue_741_append_random")
defer func() {
conn.Exec(ctx, "DROP TABLE issue_741_append_random")
}()
ddl, insertStatement, values := generateRandomInsert("issue_741_append_random")
require.NoError(t, conn.Exec(ctx, ddl))
batch, err := conn.PrepareBatch(ctx, insertStatement)
require.NoError(t, err)
require.NoError(t, batch.Append(values...))
require.NoError(t, batch.Send())
}

// TestIssue741StdAppend appends one row through the database/sql style
// interface inside a transaction, for each protocol, using a table and an
// INSERT whose column orders are shuffled independently of each other.
func TestIssue741StdAppend(t *testing.T) {
	useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
	require.NoError(t, err)
	protocols := []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP}
	for _, protocol := range protocols {
		t.Run(fmt.Sprintf("%v Protocol", protocol), func(t *testing.T) {
			// Connect with the protocol under test; previously this was hard-coded
			// to clickhouse.Native, so the HTTP subtest never used HTTP.
			conn, err := clickhouse_std_tests.GetDSNConnection("issues", protocol, useSSL, "false")
			require.NoError(t, err)
			// Best-effort removal of a table left behind by an earlier run.
			conn.Exec("DROP TABLE IF EXISTS issue_741_std_append_random")
			defer func() {
				conn.Exec("DROP TABLE issue_741_std_append_random")
			}()
			ddl, insertStatement, values := generateRandomInsert("issue_741_std_append_random")
			_, err = conn.Exec(ddl)
			require.NoError(t, err)
			scope, err := conn.Begin()
			require.NoError(t, err)
			batch, err := scope.Prepare(insertStatement)
			require.NoError(t, err)
			// NOTE(review): the transaction is never committed in this test, so only
			// the append step is exercised - confirm whether a Commit is wanted here.
			_, err = batch.Exec(values...)
			require.NoError(t, err)
		})
	}
}
2 changes: 1 addition & 1 deletion tests/std/connect_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestStdConnCheck(t *testing.T) {
Value String
) Engine MergeTree() ORDER BY tuple()
`
dml = `INSERT INTO clickhouse_test_conn_check VALUES `
dml = "INSERT INTO `clickhouse_test_conn_check` VALUES "
)

env, err := GetStdTestEnvironment()
Expand Down

0 comments on commit 99cf774

Please sign in to comment.