Skip to content

Commit

Permalink
Merge pull request #26 from Adarsh-jaiss/main
Browse files Browse the repository at this point in the history
added support for mssql
  • Loading branch information
tqindia authored May 18, 2024
2 parents 707873c + cdd4ccf commit 7c6dc5b
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 3 deletions.
13 changes: 13 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/databases/bigquery"
"github.com/thesaas-company/xray/databases/mssql"

"github.com/thesaas-company/xray/databases/mysql"
"github.com/thesaas-company/xray/databases/postgres"
Expand Down Expand Up @@ -49,6 +50,12 @@ func NewClientWithConfig(dbConfig *config.Config, dbType types.DbType) (types.IS
return nil, err
}
return logger.NewLogger(redshiftClient), nil
case types.MSSQL:
mssqlClient, err := mssql.NewMSSQLFromConfig(dbConfig)
if err != nil {
return nil, err
}
return logger.NewLogger(mssqlClient), nil

default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
Expand Down Expand Up @@ -90,6 +97,12 @@ func NewClient(dbClient *sql.DB, dbType types.DbType) (types.ISQL, error) {
return nil, err
}
return logger.NewLogger(redshiftClient), nil
case types.MSSQL:
mssqlClient, err := mssql.NewMSSQL(dbClient)
if err != nil {
return nil, err
}
return logger.NewLogger(mssqlClient), nil
default:
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
Expand Down
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type Config struct {
// SecretName is the AWS secret name.
SecretName string `yaml:"secret_name" pflag:",AWS secret name"`

// Server is the MSSQL database server.
Server string `yaml:"server" pflag:",Database server"`

// AWS holds the AWS configuration details.
// AWS AWS `yaml:"aws"`
}
Expand Down
205 changes: 205 additions & 0 deletions databases/mssql/mssql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package mssql

import (
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"strings"

_ "github.com/denisenkom/go-mssqldb"
"github.com/thesaas-company/xray/config"
"github.com/thesaas-company/xray/types"
)

var DB_PASSWORD = "DB_PASSWORD"

const (
MSSQL_SCHEMA_QUERY = "SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, ORDINAL_POSITION, CHARACTER_MAXIMUM_LENGTH FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '%s'"
MSSQL_TABLES_QUERY = "USE %s; SELECT table_name FROM INFORMATION_SCHEMA.TABLES;"
)

type MSSQL struct {
Client *sql.DB
Config *config.Config
}

func NewMSSQL(client *sql.DB) (types.ISQL, error) {
return &MSSQL{
Client: client,
Config: &config.Config{},
}, nil
}

func NewMSSQLFromConfig(config *config.Config) (types.ISQL, error) {
if os.Getenv(DB_PASSWORD) == "" || len(os.Getenv(DB_PASSWORD)) == 0 { // added mysql to be more verbose about the db type
return nil, fmt.Errorf("please set %s env variable for the database", DB_PASSWORD)
}

DB_PASSWORD = os.Getenv(DB_PASSWORD)
connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%s", config.Server, config.Username, DB_PASSWORD, config.Port)

conn, err := sql.Open("mssql", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}

return &MSSQL{
Client: conn,
Config: config,
}, nil
}

func (m *MSSQL) Schema(table string) (types.Table, error) {
query := fmt.Sprintf(MSSQL_SCHEMA_QUERY, table)
rows, err := m.Client.Query(query)
if err != nil {
return types.Table{}, fmt.Errorf("error executing sql statement: %v", err)
}
defer func() {
if err := rows.Close(); err != nil {
log.Println("Failed to close rows:", err)
}
}()

var columns []types.Column
for rows.Next() {
var col types.Column
if err := rows.Scan(
&col.Name,
&col.Type,
&col.IsNullable,
&col.ColumnDefault,
&col.OrdinalPosition,
&col.CharacterMaximumLength,
); err != nil {
return types.Table{}, fmt.Errorf("error scanning rows : %v", err)
}
col.Description = "" // default description
col.Metatags = []string{} // default metatags as an empty string slice
col.Metatags = append(col.Metatags, col.Name)
col.Visibility = true // default visibility
columns = append(columns, col)
}

if err := rows.Err(); err != nil {
return types.Table{}, fmt.Errorf("error iterating over rows: %v", err)
}

return types.Table{
Name: table,
Columns: columns,
ColumnCount: int64(len(columns)),
Metatags: []string{},
}, nil
}

func (m *MSSQL) Tables(databaseName string) ([]string, error) {
query := fmt.Sprintf(MSSQL_TABLES_QUERY, databaseName)
rows, err := m.Client.Query(query)
if err != nil {
return nil, fmt.Errorf("error executing the sql statement: %v", err)
}

defer func() {
if err := rows.Close(); err != nil {
fmt.Printf("error closing the rows: %v", err)
}
}()

var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, fmt.Errorf("error scanning the database :%v", err)
}
tables = append(tables, table)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over rows :%v", err)
}

return tables, nil

}

func (m *MSSQL) Execute(query string) ([]byte, error) {
rows, err := m.Client.Query(query)
if err != nil {
return nil, fmt.Errorf("error executing the sql statement %v", err)
}

defer func() {
if err := rows.Close(); err != nil {
log.Println("failed to close rows:", err)
}
}()

columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("error getting columns : %v", err)
}

var results [][]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
pointers := make([]interface{}, len(columns))
for i := range values {
pointers[i] = &values[i]
}

if err := rows.Scan(pointers...); err != nil {
return nil, fmt.Errorf("error scanning rows:%v", err)
}

results = append(results, values)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over rows : %v", err)
}

queryResult := types.QueryResult{
Columns: columns,
Rows: results,
}

jsonData, err := json.Marshal(queryResult)
if err != nil {
return nil, fmt.Errorf("error marshalling json: %v", err)
}

return jsonData, nil
}

func (m *MSSQL) GenerateCreateTableQuery(table types.Table) string {
query := "CREATE TABLE [" + table.Name + "] ("
pk := ""
unique := ""
for i, column := range table.Columns {
colType := strings.ToUpper(column.Type)
query += "[" + column.Name + "] " + colType
if column.AutoIncrement {
query += " IDENTITY(1,1)"
}
if column.IsPrimary {
pk = " PRIMARY KEY ([" + column.Name + "])"
}
if column.DefaultValue.Valid {
query += " DEFAULT (" + column.DefaultValue.String + ")"
}
if column.IsUnique.String == "YES" && !column.IsPrimary {
unique = ", UNIQUE ([" + column.Name + "])"
}
if column.IsNullable == "NO" && !column.IsPrimary {
query += " NOT NULL"
}
if i < len(table.Columns)-1 {
query += ", "
}
}
query += pk + unique + ")"
return query
}
Loading

0 comments on commit 7c6dc5b

Please sign in to comment.