Skip to content

Commit

Permalink
Merge pull request #21 from platform9/private/atherton/jayanth/AIR-1365
Browse files Browse the repository at this point in the history
[AIR-1365] Mysql 8.0 compatibility
  • Loading branch information
jayanth-tjvrr authored Jul 22, 2024
2 parents 60cbbb7 + b218859 commit f402c71
Showing 1 changed file with 171 additions and 136 deletions.
307 changes: 171 additions & 136 deletions pkg/cfg/cfgmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
package cfg

import (
"errors"
"fmt"
rand "math/rand"
"reflect"
"sort"
"time"

consul "github.com/hashicorp/consul/api"
"database/sql"
_ "github.com/go-sql-driver/mysql"
"go.uber.org/zap"
"errors"
"fmt"
"log"
rand "math/rand"
"reflect"
"sort"
"time"

"database/sql"

_ "github.com/go-sql-driver/mysql"
consul "github.com/hashicorp/consul/api"
"go.uber.org/zap"
)

// ErrorNotFound signifies absence of SSO configuration
Expand Down Expand Up @@ -101,7 +103,7 @@ func (c *CfgMgr) GetKeystonePassword(serviceName string) (string, error) {
return password, nil
}

// AddKeystoneUser
// AddKeystoneUser
func (c *CfgMgr) AddKeystoneUser(serviceName string) error {
zap.L().Debug("Creating keystone user for serviceName ", zap.String("serviceName", serviceName))

Expand Down Expand Up @@ -151,146 +153,179 @@ func (c *CfgMgr) getValue(key string) (string, error) {
return "", fmt.Errorf("Key value for key %s not found", key)
}

func (c * CfgMgr) GetRandomPassword() string {
rand.Seed(time.Now().UnixNano())
digits := "0123456789"
lowerChars := "abcdefghijklmnopqrstuvwxyz"
upperChars := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
allChars := lowerChars + upperChars + digits
const max = 16
out := make([]byte, max)
for i := 0; i < max; i++ {
out[i] = allChars[rand.Intn(len(allChars))]
}
return string(out)
func (c *CfgMgr) GetRandomPassword() string {
rand.Seed(time.Now().UnixNano())
digits := "0123456789"
lowerChars := "abcdefghijklmnopqrstuvwxyz"
upperChars := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
allChars := lowerChars + upperChars + digits
const max = 16
out := make([]byte, max)
for i := 0; i < max; i++ {
out[i] = allChars[rand.Intn(len(allChars))]
}
return string(out)
}

func (c *CfgMgr) CreateDB(serviceName, userName string) (updateConsul bool, err error) {
dbObject, err := c.getDbObject()
if err != nil {
return false, err
}

zap.L().Debug("Creating DB for serviceName ", zap.String("serviceName", serviceName))
_, err = dbObject.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", serviceName))
if err != nil {
zap.L().Error("Error while creating database", zap.Error(err))
return false, err
}
zap.L().Info(fmt.Sprintf("Created DB '%s' successfully", serviceName))
return true, nil
dbObject, err := c.getDbObject()
if err != nil {
return false, err
}

zap.L().Debug("Creating DB for serviceName ", zap.String("serviceName", serviceName))
_, err = dbObject.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", serviceName))
if err != nil {
zap.L().Error("Error while creating database", zap.Error(err))
return false, err
}
zap.L().Info(fmt.Sprintf("Created DB '%s' successfully", serviceName))
return true, nil
}

func (c *CfgMgr) CreateGrants(dbName, userName, dbPassword string) (bool, error) {
dbObject, err := c.getDbObject()
if err != nil {
return false, err
}
rows, err := dbObject.Query("SELECT @@hostname")
if err != nil {
zap.L().Error("Error while getting hostname", zap.Error(err))
return false, err
}
defer rows.Close()
var hostname string
count := 0
for rows.Next() {
_ = rows.Scan(&hostname)
}
hosts := []string{"localhost", "%", hostname}
for _, hostName := range hosts {
before_grants := c.getGrants(userName, hostName, dbObject)
query := fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s' IDENTIFIED BY '%s'",
dbName, userName, hostName, dbPassword)
_, _ = dbObject.Exec(query)
after_grants := c.getGrants(userName, hostName, dbObject)
if (reflect.DeepEqual(before_grants, after_grants)) {
count += 1
}
}
if count == len(hosts) {
return false, nil
} else {
return true, nil
}
dbObject, err := c.getDbObject()
if err != nil {
return false, err
}
rows, err := dbObject.Query("SELECT @@hostname")
if err != nil {
zap.L().Error("Error while getting hostname", zap.Error(err))
return false, err
}
defer rows.Close()
var hostname string
count := 0
for rows.Next() {
_ = rows.Scan(&hostname)
}
hosts := []string{"localhost", "%", hostname}
for _, hostName := range hosts {
before_grants := c.getGrants(userName, hostName, dbObject)

var exists bool
query := "SELECT EXISTS(SELECT 1 FROM mysql.user WHERE user = ? AND host = ?)"
err := dbObject.QueryRow(query, userName, hostName).Scan(&exists)
if err != nil {
log.Fatalf("Error checking if user exists: %v", err)
}

if exists {
// Update the password if the user exists
updatePasswordQuery := "ALTER USER '" + userName + "'@'" + hostName + "' IDENTIFIED BY '" + dbPassword + "'"
_, err := dbObject.Exec(updatePasswordQuery)
if err != nil {
log.Fatalf("Error updating user password: %v", err)
}
} else {
// Create the user if it does not exist
createUserQuery := "CREATE USER '" + userName + "'@'" + hostName + "' IDENTIFIED BY '" + dbPassword + "'"
_, err := dbObject.Exec(createUserQuery)
if err != nil {
log.Fatalf("Error creating user: %v", err)
}
fmt.Printf("User '%s'@'%s' created successfully.\n", userName, hostName)
}

grantPrivilegesQuery := "GRANT ALL PRIVILEGES ON `" + dbName + "`.* TO '" + userName + "'@'" + hostName + "'"
_, err = dbObject.Exec(grantPrivilegesQuery)

if err != nil {
// Handle the error
fmt.Println("Error granting privileges:", err)
}

// query := fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO '%s'@'%s' IDENTIFIED BY '%s'",
// dbName, userName, hostName, dbPassword)
// _, _ = dbObject.Exec(query)
after_grants := c.getGrants(userName, hostName, dbObject)
if reflect.DeepEqual(before_grants, after_grants) {
count += 1
}
}
if count == len(hosts) {
return false, nil
} else {
return true, nil
}
}

func (c *CfgMgr) getDbDetails() (string, string, string, string, error) {
dbserver, err := c.getValue(fmt.Sprintf("%s/keystone/dbserver_key", c.CustomerKeyPrefix))
if err != nil {
zap.L().Error("Cannot get dbserver_key from consul store", zap.Error(err))
return "", "", "", "", err
}
host, err := c.getValue(fmt.Sprintf("%s/host", dbserver))
if err != nil {
zap.L().Error("Cannot get host key from consul store", zap.Error(err))
return "", "", "", "", err
}
port, err := c.getValue(fmt.Sprintf("%s/port", dbserver))
if err != nil {
zap.L().Error("Cannot get port key from consul store", zap.Error(err))
return "", "", "", "", err
}
adminUser, err := c.getValue(fmt.Sprintf("%s/admin_user", dbserver))
if err != nil {
zap.L().Error("Cannot get admin_user key from consul store", zap.Error(err))
return "", "", "", "", err
}
adminPass, err := c.getValue(fmt.Sprintf("%s/admin_pass", dbserver))
if err != nil {
zap.L().Error("Cannot get admin_pass key from consul store", zap.Error(err))
return "", "", "", "", err
}
return host, port, adminUser, adminPass, nil
dbserver, err := c.getValue(fmt.Sprintf("%s/keystone/dbserver_key", c.CustomerKeyPrefix))
if err != nil {
zap.L().Error("Cannot get dbserver_key from consul store", zap.Error(err))
return "", "", "", "", err
}
host, err := c.getValue(fmt.Sprintf("%s/host", dbserver))
if err != nil {
zap.L().Error("Cannot get host key from consul store", zap.Error(err))
return "", "", "", "", err
}
port, err := c.getValue(fmt.Sprintf("%s/port", dbserver))
if err != nil {
zap.L().Error("Cannot get port key from consul store", zap.Error(err))
return "", "", "", "", err
}
adminUser, err := c.getValue(fmt.Sprintf("%s/admin_user", dbserver))
if err != nil {
zap.L().Error("Cannot get admin_user key from consul store", zap.Error(err))
return "", "", "", "", err
}
adminPass, err := c.getValue(fmt.Sprintf("%s/admin_pass", dbserver))
if err != nil {
zap.L().Error("Cannot get admin_pass key from consul store", zap.Error(err))
return "", "", "", "", err
}
return host, port, adminUser, adminPass, nil
}

func (c *CfgMgr) getDbObject() (*sql.DB, error) {
host, port, adminUser, adminPass, err := c.getDbDetails()
if err != nil {
return nil, err
}
dbObject, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/", adminUser, adminPass, host, port))
if err != nil {
zap.L().Error("Can't connect to MySQL", zap.Error(err))
return nil, err
}
return dbObject, nil
host, port, adminUser, adminPass, err := c.getDbDetails()
if err != nil {
return nil, err
}
dbObject, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/", adminUser, adminPass, host, port))
if err != nil {
zap.L().Error("Can't connect to MySQL", zap.Error(err))
return nil, err
}
return dbObject, nil
}

func (c *CfgMgr) UpdateConsul(serviceName, userName, dbPassword string) error {
host, port, _, _, err := c.getDbDetails()
if err != nil {
return err
}
dbPrefix := fmt.Sprintf("%s/%s/db", c.CustomerKeyPrefix, serviceName)
ops := consul.KVTxnOps{
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/name", dbPrefix), Value: []byte(serviceName)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/password", dbPrefix), Value: []byte(dbPassword)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/user", dbPrefix), Value: []byte(userName)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/host", dbPrefix), Value: []byte(host)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/port", dbPrefix), Value: []byte(port)},
}
_, _, _, err = c.ConsulKV.Txn(ops, nil)
if err != nil {
zap.L().Error("Can't write db config to Consul", zap.Error(err))
return err
}
return nil
host, port, _, _, err := c.getDbDetails()
if err != nil {
return err
}
dbPrefix := fmt.Sprintf("%s/%s/db", c.CustomerKeyPrefix, serviceName)
ops := consul.KVTxnOps{
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/name", dbPrefix), Value: []byte(serviceName)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/password", dbPrefix), Value: []byte(dbPassword)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/user", dbPrefix), Value: []byte(userName)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/host", dbPrefix), Value: []byte(host)},
&consul.KVTxnOp{Verb: consul.KVSet, Key: fmt.Sprintf("%s/port", dbPrefix), Value: []byte(port)},
}
_, _, _, err = c.ConsulKV.Txn(ops, nil)
if err != nil {
zap.L().Error("Can't write db config to Consul", zap.Error(err))
return err
}
return nil
}

func (c *CfgMgr) getGrants(userName, host string, dbObject *sql.DB) []string {
var grants []string
var field string
rows, err := dbObject.Query(fmt.Sprintf("SHOW GRANTS FOR %s@%s", userName, host))
if err != nil {
zap.L().Error("Error while getting grants for user")
return grants
}
defer rows.Close()
for rows.Next() {
_ = rows.Scan(&field)
}
grants = append(grants, field)
sort.Strings(grants)
return grants
var grants []string
var field string
rows, err := dbObject.Query(fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", userName, host))
if err != nil {
zap.L().Error("Error while getting grants for user", zap.Error(err))
return grants
}
defer rows.Close()
for rows.Next() {
_ = rows.Scan(&field)
}
grants = append(grants, field)
sort.Strings(grants)
return grants
}

0 comments on commit f402c71

Please sign in to comment.