Skip to content

Commit

Permalink
Refactor long functions
Browse files Browse the repository at this point in the history
  • Loading branch information
noborus committed Dec 11, 2023
1 parent b35cc62 commit 97f3e57
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package trdsql

import (
"context"
"database/sql"
"log"
"strings"

Expand All @@ -18,6 +19,9 @@ type Exporter interface {
// WriteFormat represents a structure that satisfies Exporter.
type WriteFormat struct {
Writer
columns []string
types []string
multi bool
}

// NewExporter returns trdsql default Exporter.
Expand All @@ -38,18 +42,20 @@ func (e *WriteFormat) Export(db *DB, sql string) error {
// ExportContext is called from ExecContext.
func (e *WriteFormat) ExportContext(ctx context.Context, db *DB, sqlQuery string) error {
queries := sqlss.SplitQueries(sqlQuery)
e.multi = false
if !multi || len(queries) == 1 {
return e.exportContext(ctx, db, false, sqlQuery)
e.multi = true
return e.exportContext(ctx, db, sqlQuery)
}
for _, query := range queries {
if err := e.exportContext(ctx, db, true, query); err != nil {
if err := e.exportContext(ctx, db, query); err != nil {
return err
}
}
return nil
}

func (e *WriteFormat) exportContext(ctx context.Context, db *DB, multi bool, query string) error {
func (e *WriteFormat) exportContext(ctx context.Context, db *DB, query string) error {
if db.Tx == nil {
return ErrNoTransaction
}
Expand All @@ -73,6 +79,7 @@ func (e *WriteFormat) exportContext(ctx context.Context, db *DB, multi bool, que
if err != nil {
return err
}
e.columns = columns

defer func() {
if err = rows.Close(); err != nil {
Expand All @@ -81,14 +88,9 @@ func (e *WriteFormat) exportContext(ctx context.Context, db *DB, multi bool, que
}()

// No data is not output for multiple queries.
if multi && len(columns) == 0 {
if e.multi && len(e.columns) == 0 {
return nil
}
values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}

columnTypes, err := rows.ColumnTypes()
if err != nil {
Expand All @@ -98,8 +100,19 @@ func (e *WriteFormat) exportContext(ctx context.Context, db *DB, multi bool, que
for i, ct := range columnTypes {
types[i] = ct.DatabaseTypeName()
}
e.types = types

return e.write(ctx, rows)
}

func (e *WriteFormat) write(ctx context.Context, rows *sql.Rows) error {
values := make([]interface{}, len(e.columns))
scanArgs := make([]interface{}, len(e.columns))
for i := range values {
scanArgs[i] = &values[i]
}

if err := e.Writer.PreWrite(columns, types); err != nil {
if err := e.Writer.PreWrite(e.columns, e.types); err != nil {
return err
}

Expand All @@ -113,7 +126,7 @@ func (e *WriteFormat) exportContext(ctx context.Context, db *DB, multi bool, que
if err := rows.Scan(scanArgs...); err != nil {
return err
}
if err := e.Writer.WriteRow(values, columns); err != nil {
if err := e.Writer.WriteRow(values, e.columns); err != nil {
return err
}
}
Expand Down

0 comments on commit 97f3e57

Please sign in to comment.