Skip to content

Commit

Permalink
Merge pull request #29 from skpr/provider-model
Browse files Browse the repository at this point in the history
Implement a provider model for RDS/Stdout mysql providers.
  • Loading branch information
fubarhouse authored Jul 24, 2024
2 parents 4acb798 + d4e4d81 commit d92562b
Show file tree
Hide file tree
Showing 19 changed files with 532 additions and 236 deletions.
106 changes: 106 additions & 0 deletions .github/workflows/functional-test.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
/*!40101 SET NAMES utf8mb4 */;
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
/*!40103 SET TIME_ZONE='+00:00' */;
/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;

--
-- Structure for table `orders`
--

DROP TABLE IF EXISTS `orders`;
/*!40101 SET @saved_cs_client = @@character_set_client */;
/*!40101 SET character_set_client = utf8 */;
CREATE TABLE `orders` (
`id` int NOT NULL AUTO_INCREMENT,
`user_id` int DEFAULT NULL,
`total_amount` decimal(10,2) NOT NULL,
`order_date` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
KEY `user_id` (`user_id`),
CONSTRAINT `orders_ibfk_1` FOREIGN KEY (`user_id`) REFERENCES `users` (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
/*!40101 SET character_set_client = @saved_cs_client */;

--
-- Data for table `orders` -- 3 rows
--

LOCK TABLES `orders` WRITE;
/*!40000 ALTER TABLE `orders` DISABLE KEYS */;
set autocommit=0;
INSERT INTO `orders` VALUES (1,1,'1079.98','2024-06-27 04:48:24'),(2,2,'499.99','2024-06-27 04:48:24'),(3,3,'579.98','2024-06-27 04:48:24');
/*!40000 ALTER TABLE `orders` ENABLE KEYS */;
UNLOCK TABLES;
commit;

--
-- Structure for table `products`
--

DROP TABLE IF EXISTS `products`;
/*!40101 SET @saved_cs_client = @@character_set_client */;
/*!40101 SET character_set_client = utf8 */;
CREATE TABLE `products` (
`id` int NOT NULL AUTO_INCREMENT,
`name` varchar(100) NOT NULL,
`price` decimal(10,2) NOT NULL,
`stock` int NOT NULL DEFAULT '0',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
/*!40101 SET character_set_client = @saved_cs_client */;

--
-- Data for table `products` -- 3 rows
--

LOCK TABLES `products` WRITE;
/*!40000 ALTER TABLE `products` DISABLE KEYS */;
set autocommit=0;
INSERT INTO `products` VALUES (1,'Laptop','999.99',50),(2,'Smartphone','499.99',100),(3,'Headphones','79.99',200);
/*!40000 ALTER TABLE `products` ENABLE KEYS */;
UNLOCK TABLES;
commit;

--
-- Structure for table `users`
--

DROP TABLE IF EXISTS `users`;
/*!40101 SET @saved_cs_client = @@character_set_client */;
/*!40101 SET character_set_client = utf8 */;
CREATE TABLE `users` (
`id` int NOT NULL AUTO_INCREMENT,
`username` varchar(50) NOT NULL,
`email` varchar(100) NOT NULL,
`created_at` timestamp NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `username` (`username`),
UNIQUE KEY `email` (`email`)
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
/*!40101 SET character_set_client = @saved_cs_client */;

--
-- Data for table `users` -- 3 rows
--

LOCK TABLES `users` WRITE;
/*!40000 ALTER TABLE `users` DISABLE KEYS */;
set autocommit=0;
INSERT INTO `users` VALUES (1,'john_doe','john@example.com','2024-06-27 04:48:24'),(2,'jane_smith','jane@example.com','2024-06-27 04:48:24'),(3,'bob_johnson','bob@example.com','2024-06-27 04:48:24');
/*!40000 ALTER TABLE `users` ENABLE KEYS */;
UNLOCK TABLES;
commit;
/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */;
/*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */;
/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */;
/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */;
/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */;
48 changes: 48 additions & 0 deletions .github/workflows/functional-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: Functional Test

on:
pull_request:
types: [opened, synchronize, reopened]
branches: ['main']

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- uses: actions/setup-go@v3
with:
go-version: '1.21'

- name: Build MTK
run: make build

- name: MySQL:Start
run: sudo /etc/init.d/mysql start

- name: Database:Create
run: mysql -e 'CREATE DATABASE test_db;' -uroot -proot

- name: Database:Import
run: mysql -uroot -proot test_db < .github/workflows/functional-test.sql

- name: Database:Dump
run: bin/mtk-linux-amd64 dump test_db --port=3306 --user=root --password=root --host=127.0.0.1 > mtk-dump.sql

- name: Database:Drop
run: mysql -e 'DROP DATABASE test_db;' -uroot -proot

- name: Database:Clean
# Remove the last line from the file - the dump timestamp.
run: sed '$d' mtk-dump.sql > mtk-dump-clean.sql

- name: Database:Recreate
run: mysql -e 'CREATE DATABASE test_db;' -uroot -proot

- name: Database:Reimport
run: mysql -uroot -proot test_db < mtk-dump-clean.sql

- name: Compare results
run: diff -q .github/workflows/functional-test.sql mtk-dump-clean.sql
32 changes: 11 additions & 21 deletions cmd/mtk/dump/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/spf13/cobra"

"github.com/skpr/mtk/internal/mysql"
"github.com/skpr/mtk/internal/mysql/provider"
"github.com/skpr/mtk/pkg/config"
"github.com/skpr/mtk/pkg/envar"
)
Expand All @@ -31,21 +32,19 @@ const cmdExample = `
# List all database tables and dump each table to a file.
mtk table list <database> | xargs -I {} sh -c "mtk dump <database> '{}' > '{}.sql'"`

// Options is the commandline options for 'config' sub command
// Options is the commandline options for 'dump' sub command
type Options struct {
ConfigFile string
ExtendedInsertRows int

DataExport bool
DataPath string
Region string
}

// NewOptions will return a new Options.
func NewOptions() Options {
return Options{}
}

func NewCommand(conn *mysql.Connection) *cobra.Command {
// NewCommand will return a new Cobra command.
func NewCommand(conn *mysql.Connection, provider, region, s3uri string) *cobra.Command {
o := NewOptions()

cmd := &cobra.Command{
Expand All @@ -72,7 +71,7 @@ func NewCommand(conn *mysql.Connection) *cobra.Command {
panic(err)
}

if err := o.Run(os.Stdout, logger, conn, database, table, cfg); err != nil {
if err := o.Run(os.Stdout, logger, conn, database, table, provider, region, s3uri, cfg); err != nil {
panic(err)
}
},
Expand All @@ -81,22 +80,19 @@ func NewCommand(conn *mysql.Connection) *cobra.Command {
cmd.Flags().StringVar(&o.ConfigFile, "config", envar.GetStringWithFallback("", envar.Config), "Path to the configuration file which contains the rules")
cmd.Flags().IntVar(&o.ExtendedInsertRows, "extended-insert-rows", envar.GetIntWithFallback(1000, envar.ExtendedInsertRows), "The number of rows to batch per INSERT statement")

cmd.Flags().BoolVar(&o.DataExport, "data-export", false, "Export data using SELECT INTO OUTFILE statements.")
cmd.Flags().StringVar(&o.DataPath, "data-path", "", "The S3 bucket URI (e.g s3://my/bucket/path).")
cmd.Flags().StringVar(&o.Region, "region", "", "The S3 bucket region.")

return cmd
}

func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table string, cfg config.Rules) error {
// Run will execute the dump command.
func (o *Options) Run(w io.Writer, logger *log.Logger, conn *mysql.Connection, database, table, provider, region, uri string, cfg config.Rules) error {
db, err := conn.Open(database)
if err != nil {
return fmt.Errorf("failed to open database connection: %w", err)
}

defer db.Close()

client := mysql.NewClient(db, logger)
client := mysql.NewClient(db, logger, provider, region, uri)

if table != "" {
return o.runDumpTable(w, client, table, cfg)
Expand All @@ -118,11 +114,8 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru
return err
}

params := mysql.DumpParams{
params := provider.DumpParams{
ExtendedInsertRows: o.ExtendedInsertRows,
DataExport: o.DataExport,
DataPath: o.DataPath,
Region: o.Region,
}

// Assign nodata tables.
Expand Down Expand Up @@ -165,11 +158,8 @@ func (o *Options) runDumpTables(w io.Writer, client *mysql.Client, cfg config.Ru
//
// eg. runDumpTables has to perform ListTablesByGlobal for each table, which is slow.
func (o *Options) runDumpTable(w io.Writer, client *mysql.Client, table string, cfg config.Rules) error {
params := mysql.DumpParams{
params := provider.DumpParams{
ExtendedInsertRows: o.ExtendedInsertRows,
DataExport: o.DataExport,
DataPath: o.DataPath,
Region: o.Region,
}

// If this table matches an ignore glob, then skip it.
Expand Down
13 changes: 11 additions & 2 deletions cmd/mtk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ import (
"github.com/skpr/mtk/pkg/envar"
)

var conn = new(mysql.Connection)
var (
conn = new(mysql.Connection)
awsRegion string
providerName string
s3Uri string
)

const cmdExample = `
export MTK_HOSTNAME=localhost
Expand Down Expand Up @@ -51,6 +56,10 @@ func init() {
cmd.PersistentFlags().StringVar(&conn.Protocol, "protocol", envar.GetStringWithFallback("tcp", envar.Protocol, envar.MySQLProtocol), "Connection protocol to use when connecting to MySQL instance")
cmd.PersistentFlags().Int32Var(&conn.Port, "port", int32(envar.GetIntWithFallback(3306, envar.Port, envar.MySQLPort)), "Port to connect to the MySQL instance on")
cmd.PersistentFlags().IntVar(&conn.MaxConn, "max-conn", envar.GetIntWithFallback(50, envar.MaxConn), "Sets the maximum number of open connections to the database")

cmd.PersistentFlags().StringVar(&providerName, "provider", envar.GetStringWithFallback("stdout", envar.Provider), "The provider to use (either 'stdout' or 'rds')")
cmd.PersistentFlags().StringVar(&awsRegion, "region", envar.GetStringWithFallback("", envar.Region), "The AWS region to use for S3 when connecting to the MySQL RDS instance")
cmd.PersistentFlags().StringVar(&s3Uri, "uri", envar.GetStringWithFallback("", envar.Uri), "The S3 URI to use for exporting to S3 when exporting data from the MySQL RDS instance")
}

func main() {
Expand All @@ -68,7 +77,7 @@ func main() {
usageTemplate = re.ReplaceAllLiteralString(usageTemplate, `{{StyleHeading "Flags:"}}`)
cmd.SetUsageTemplate(usageTemplate)

cmd.AddCommand(dump.NewCommand(conn))
cmd.AddCommand(dump.NewCommand(conn, providerName, awsRegion, s3Uri))
cmd.AddCommand(table.NewCommand(conn))

if err := cmd.Execute(); err != nil {
Expand Down
1 change: 1 addition & 0 deletions cmd/mtk/table/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const cmdExample = `
# List all database tables.
mtk table list <database>`

// NewCommand will execute the table command.
func NewCommand(conn *mysql.Connection) *cobra.Command {
cmd := &cobra.Command{
Use: "table",
Expand Down
4 changes: 2 additions & 2 deletions cmd/mtk/table/list/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const cmdExample = `
# List all database tables and dump each table to a file.
mtk table list <database> | xargs -I {} sh -c "mtk dump <database> '{}' > '{}.sql'"`

// Options is the commandline options for 'config' sub command
// Options is the commandline options for 'list' sub command
type Options struct {
ConfigFile string
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func (o *Options) Run(logger *log.Logger, conn *mysql.Connection, database strin

defer db.Close()

client := mysql.NewClient(db, logger)
client := mysql.NewClient(db, logger, "", "", "")

tables, err := client.QueryTables()
if err != nil {
Expand Down
38 changes: 20 additions & 18 deletions internal/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"log"

"github.com/go-sql-driver/mysql"

"github.com/skpr/mtk/internal/mysql/provider"
)

const (
Expand All @@ -16,6 +18,7 @@ const (
OperationNoData = "nodata"
)

// Connection is a struct containing metadata for the database connection.
type Connection struct {
Hostname string
Username string
Expand All @@ -25,6 +28,7 @@ type Connection struct {
MaxConn int
}

// Open will Open a new database connection.
func (o Connection) Open(database string) (*sql.DB, error) {
cfg := mysql.Config{
User: o.Username,
Expand All @@ -49,33 +53,31 @@ func (o Connection) Open(database string) (*sql.DB, error) {
type Client struct {
DB *sql.DB
Logger *log.Logger

// A field for caching a list of tables for this database.
cachedTables []string

// Provider configuration.
Provider string
// For the AWS RDS Provider, specify the AWS Region.
Region string
// For the AWS RDS Provider, specify the S3 URI.
URI string
}

// NewClient for dumping a full or single table from a database.
func NewClient(db *sql.DB, logger *log.Logger) *Client {
func NewClient(db *sql.DB, logger *log.Logger, provider, region, uri string) *Client {
return &Client{
DB: db,
Logger: logger,
DB: db,
Logger: logger,
Provider: provider,
Region: region,
URI: uri,
}
}

// DumpParams is used to pass parameters to the Dump function.
type DumpParams struct {
SelectMap map[string]map[string]string
WhereMap map[string]string
FilterMap map[string]string
UseTableLock bool
ExtendedInsertRows int

DataExport bool
DataPath string
Region string
}

// DumpTables will write all table data to a single writer.
func (d *Client) DumpTables(w io.Writer, params DumpParams) error {
func (d *Client) DumpTables(w io.Writer, params provider.DumpParams) error {
if err := d.WriteHeader(w); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
Expand All @@ -96,7 +98,7 @@ func (d *Client) DumpTables(w io.Writer, params DumpParams) error {
}

// DumpTable is convenient if you wish to coordinate a dump eg. Single file per table.
func (d *Client) DumpTable(w io.Writer, table string, params DumpParams) error {
func (d *Client) DumpTable(w io.Writer, table string, params provider.DumpParams) error {
if err := d.WriteHeader(w); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
Expand Down
Loading

0 comments on commit d92562b

Please sign in to comment.