Skip to content

Commit

Permalink
Merge pull request #30 from Adarsh-jaiss/main
Browse files Browse the repository at this point in the history
Added Base64 decoding
  • Loading branch information
tqindia authored May 22, 2024
2 parents 2d9c14e + f5a5810 commit 759e5e8
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 98 deletions.
47 changes: 1 addition & 46 deletions cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,6 @@ var shellCmd = &cobra.Command{
break
}

if dbType == "postgres" {
query = PostgresMetaCommands(query)
}

if err := queryExecute(query, db); err != nil {
fmt.Println("Error executing query:", err)
}
Expand Down Expand Up @@ -199,45 +195,4 @@ func parseDbType(s string) xrayTypes.DbType {
default:
return xrayTypes.MySQL
}
}

// PostgresMetaCommands translates PostgreSQL meta commands to SQL queries
func PostgresMetaCommands(query string) string {
switch query {
case "\\l":
return "SELECT datname FROM pg_database WHERE datistemplate = false;"
case "\\dt":
return "SELECT * FROM pg_catalog.pg_tables;"
case "\\d":
return "SELECT * FROM pg_catalog.pg_tables;"
case "\\c":
return "switch_database"
case "\\q":
return "exit"
case "\\?":
return "help"
case "\\h":
return "help"
case "\\du":
return "SELECT * FROM pg_catalog.pg_roles;"
case "\\conninfo":
return "SELECT * FROM pg_stat_activity WHERE pid = pg_backend_pid();"
default:
// Handle meta commands with parameters
if strings.HasPrefix(query, "\\c ") {
dbName := strings.TrimPrefix(query, "\\c ")
return fmt.Sprintf("switch_database %s", dbName)
} else if strings.HasPrefix(query, "\\d ") {
tableName := strings.TrimPrefix(query, "\\d ")
return fmt.Sprintf("SELECT * FROM %s;", tableName)
} else if strings.HasPrefix(query, "\\dn ") {
schemaName := strings.TrimPrefix(query, "\\dn ")
return fmt.Sprintf("SELECT nspname FROM pg_catalog.pg_namespace WHERE nspname = '%s';", schemaName)
} else if strings.HasPrefix(query, "\\dp ") {
tableName := strings.TrimPrefix(query, "\\dp ")
return fmt.Sprintf("SELECT * FROM pg_catalog.pg_statio_all_tables WHERE relname = '%s';", tableName)
}
}
// If the query doesn't match any known meta commands, return it unchanged
return query
}
}
23 changes: 20 additions & 3 deletions databases/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (b *BigQuery) Execute(query string) ([]byte, error) {
}

// Scan the result into a slice of slices
var results [][]interface{}
var results []map[string]interface{}
for rows.Next() {
// create a slice of values and pointers
values := make([]interface{}, len(columns))
Expand All @@ -139,7 +139,24 @@ func (b *BigQuery) Execute(query string) ([]byte, error) {
return nil, fmt.Errorf("error scanning row: %v", err)
}

results = append(results, values)
// Create a map for the current row
rowMap := make(map[string]interface{})
for i, colName := range columns {
// If the value is of type []byte (which is used for RECORD data types),
// we attempt to unmarshal it into a map[string]interface{}
if b, ok := values[i].([]byte); ok {
var m map[string]interface{}
if err := json.Unmarshal(b, &m); err == nil {
rowMap[colName] = m
} else {
rowMap[colName] = values[i]
}
} else {
rowMap[colName] = values[i]
}
}

results = append(results, rowMap)
}

// Check for errors from iterating over rows
Expand All @@ -148,7 +165,7 @@ func (b *BigQuery) Execute(query string) ([]byte, error) {
}

// Convert the result to JSON
queryResult := types.QueryResult{
queryResult := types.BigQueryResult{
Columns: columns,
Rows: results,
}
Expand Down
28 changes: 28 additions & 0 deletions databases/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"os"
"strings"

"encoding/base64"

_ "github.com/go-sql-driver/mysql"
"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/types"
Expand Down Expand Up @@ -135,6 +137,19 @@ func (m *MySQL) Execute(query string) ([]byte, error) {
return nil, fmt.Errorf("error scanning row: %v", err)
}

// Decode base64 data
for i, val := range values {
strVal, ok := val.(string)
if ok && isBase64(strVal) {
// Redecode the value to get the decoded result
decoded, err := base64.StdEncoding.DecodeString(strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
values[i] = string(decoded)
}
}

results = append(results, values)
}

Expand All @@ -156,6 +171,19 @@ func (m *MySQL) Execute(query string) ([]byte, error) {
return jsonData, nil
}

// isBase64 checks if a string is a valid base64 string.
func isBase64(s string) bool {
if len(s)%4 != 0 {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
}

// Tables retrieves the list of tables in the given database.
// It takes the database name as an argument and returns a list of table names.
func (m *MySQL) Tables(databaseName string) ([]string, error) {
Expand Down
2 changes: 1 addition & 1 deletion databases/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestExecute(t *testing.T) {
// TestGetTableName is a unit test function that tests the Tables method of the MySQL struct.
// It creates a mock instance of MySQL, sets the expected return values, and calls the method under test.
// It then asserts the expected return values and checks if the method was called with the correct arguments.
func TestGetTableName(t *testing.T) {
func TestTables(t *testing.T) {
// create a new mock database connection
db, mock := MockDB()
defer func() {
Expand Down
67 changes: 66 additions & 1 deletion databases/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package postgres

import (
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -137,6 +138,7 @@ func (p *Postgres) Schema(table string) (types.Table, error) {
// It returns an error if the SQL query fails.
func (p *Postgres) Execute(query string) ([]byte, error) {
// execute the sql statement
query = PostgresMetaCommands(query)
rows, err := p.Client.Query(query)
if err != nil {
return nil, fmt.Errorf("error executing sql statement: %v", err)
Expand All @@ -160,14 +162,25 @@ func (p *Postgres) Execute(query string) ([]byte, error) {
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
for i := range values {
// create a slice of pointers to the values
pointers[i] = &values[i]
}

if err := rows.Scan(pointers...); err != nil {
return nil, fmt.Errorf("error scanning row: %v", err)
}

// Decode base64 data
for _, val := range values {
strVal, ok := val.(*string)
if ok && strVal != nil && isBase64(*strVal) {
decoded, err := base64.StdEncoding.DecodeString(*strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
*strVal = string(decoded)
}
}

results = append(results, values)
}

Expand All @@ -189,6 +202,18 @@ func (p *Postgres) Execute(query string) ([]byte, error) {
return jsonData, nil
}

func isBase64(s string) bool {
if len(s)%4 != 0 {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
}

// Tables returns a list of all tables in the given database.
// It returns an error if the SQL query fails.
func (p *Postgres) Tables(databaseName string) ([]string, error) {
Expand Down Expand Up @@ -270,3 +295,43 @@ func TableToString(t types.Table) string {
t.ColumnCount,
)
}

func PostgresMetaCommands(query string) string {
switch query {
case "\\l":
return "SELECT datname FROM pg_database WHERE datistemplate = false;"
case "\\dt":
return "SELECT * FROM pg_catalog.pg_tables;"
case "\\d":
return "SELECT * FROM pg_catalog.pg_tables;"
case "\\c":
return "switch_database"
case "\\q":
return "exit"
case "\\?":
return "help"
case "\\h":
return "help"
case "\\du":
return "SELECT * FROM pg_catalog.pg_roles;"
case "\\conninfo":
return "SELECT * FROM pg_stat_activity WHERE pid = pg_backend_pid();"
default:
// Handle meta commands with parameters
if strings.HasPrefix(query, "\\c ") {
dbName := strings.TrimPrefix(query, "\\c ")
return fmt.Sprintf("switch_database %s", dbName)
} else if strings.HasPrefix(query, "\\d ") {
tableName := strings.TrimPrefix(query, "\\d ")
return fmt.Sprintf("SELECT * FROM %s;", tableName)
} else if strings.HasPrefix(query, "\\dn ") {
schemaName := strings.TrimPrefix(query, "\\dn ")
return fmt.Sprintf("SELECT nspname FROM pg_catalog.pg_namespace WHERE nspname = '%s';", schemaName)
} else if strings.HasPrefix(query, "\\dp ") {
tableName := strings.TrimPrefix(query, "\\dp ")
return fmt.Sprintf("SELECT * FROM pg_catalog.pg_statio_all_tables WHERE relname = '%s';", tableName)
}
}
// If the query doesn't match any known meta commands, return it unchanged
return query
}
25 changes: 25 additions & 0 deletions databases/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redshift
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -165,6 +166,18 @@ func (r *Redshift) Execute(query string) ([]byte, error) {
return nil, fmt.Errorf("error scanning row: %v", err)
}

// Decode base64 data
for _, val := range values {
strVal, ok := val.(*string)
if ok && strVal != nil && isBase64(*strVal) {
decoded, err := base64.StdEncoding.DecodeString(*strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
*strVal = string(decoded)
}
}

results = append(results, values)
}

Expand All @@ -186,6 +199,18 @@ func (r *Redshift) Execute(query string) ([]byte, error) {
return jsonData, nil
}

func isBase64(s string) bool {
if len(s)%4 != 0 {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
}

// GenerateCreateTableQuery generates a CREATE TABLE query for Redshift.
// It takes a Table struct as an argument and returns a string.
func (r *Redshift) GenerateCreateTableQuery(table types.Table) string {
Expand Down
27 changes: 27 additions & 0 deletions databases/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package snowflake

import (
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -176,6 +177,19 @@ func (s *Snowflake) Execute(query string) ([]byte, error) {
return nil, fmt.Errorf("error scanning row: %v", err)
}

// Decode base64 data
for i, val := range values {
strVal, ok := val.(string)
if ok && isBase64(strVal) {
// Redecode the value to get the decoded result
decoded, err := base64.StdEncoding.DecodeString(strVal)
if err != nil {
return nil, fmt.Errorf("error decoding base64 data: %v", err)
}
values[i] = string(decoded)
}
}

results = append(results, values)
}

Expand All @@ -198,6 +212,19 @@ func (s *Snowflake) Execute(query string) ([]byte, error) {
return jsonData, nil
}

// isBase64 checks if a string is a valid base64 string.
func isBase64(s string) bool {
if len(s)%4 != 0 {
return false
}
// Try to decode the string
_, err := base64.StdEncoding.DecodeString(s)
// If decoding succeeds, err will be nil, and the function will return true
// If decoding fails, err will not be nil, and the function will return false
// Also we do not have access to decoded value, so we are not using it
return err == nil
}

// GenerateCreateTableQuery generates a CREATE TABLE query for Snowflake.
// It takes a Table struct as an argument and returns the query as a string.
func (s *Snowflake) GenerateCreateTableQuery(table types.Table) string {
Expand Down
Loading

0 comments on commit 759e5e8

Please sign in to comment.