diff --git a/README.md b/README.md index b3f213ce..0df5892b 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ wireguard interface stats. See the `cap_add` and `network_mode` options on the d Set the `SESSION_SECRET` environment variable to a random value. -In order to sent the wireguard configuration to clients via email, set the following environment variables: +In order to send the wireguard configuration to clients via email, set the following environment variables: - using SendGrid API @@ -55,6 +55,22 @@ EMAIL_FROM_ADDRESS: the sender's email address EMAIL_FROM_NAME: the sender's name ``` +In order to connect to a database, set the following environment +variables: + +``` +DB_TYPE +DB_HOST +DB_PORT +DB_DATABASE +DB_USERNAME +DB_PASSWORD +DB_TLS: the TLS option +``` + +For details on the values that these variables should be set to, see the +section for your desired database. + ### Using binary file Download the binary file from the release and run it with command: @@ -63,6 +79,61 @@ Download the binary file from the release and run it with command: ./wireguard-ui ``` +## Databases + +By default, all the data for the application is stored in JSON files in +the `./db` directory. By using the `--db-type` command line option or by +setting the `DB_TYPE` environment variable, you can choose to use a +different backend. Note: for some backends, other options may need to be +set. + +Backend options: + +| Value | Database | Other options | +| ----- | -------- | ------------- | +| jsondb | JSON files in `./db` | None | +| mysql | MySQL or MariaDB server | `DB_HOST` `DB_PORT` `DB_DATABASE` `DB_USERNAME` `DB_PASSWORD` `DB_TLS` | + +### JSONDB + +When using the JSONDB database, all of the data is stored in separate +JSON files in the `./db` directory. This is the default database and no +special configuration is required. + +### MySQL + +In order to use a MySQL or MariaDB server, you will first have to set +the `DB_TYPE` environment variable to `mysql`. You should then specify +the hostname or IP address of the database server using `DB_HOST` as +well as the port on which the database server is listening, if it is +different from the default of `3306`. `DB_DATABASE` is the name of the +database that WireGuard-UI is to use. Please ensure that the database is +empty before you start WireGuard-UI for the first time otherwise the +tables will not be initialized properly. `DB_USERNAME` and `DB_PASSWORD` +should contain the login details for a user with the following +permissions for the database: + +* SELECT +* INSERT +* UPDATE +* DELETE +* CREATE +* ALTER + +`DB_TLS` sets the TLS configuration for the database connection. It +defaults to `false` and can be one of the following values: + +| Option | Description | +| ------ | ----------- | +| false | Never use TLS (default) | +| true | Enable TLS / SSL encrypted connection to the server | +| prefered | Use TLS when advertised by the server | +| skip-verify | Use TLS, but don't check against a CA | + +After you have set these options, you should be able to start the +WireGuard-UI server. The server will then initialize the database and +insert the default configuration. If this process is interrupted, you +will have to empty the database and restart the initialization. ## Auto restart WireGuard daemon WireGuard-UI only takes care of configuration generation. You can use systemd to watch for the changes and restart the service. Following is an example: diff --git a/go.mod b/go.mod index 918d7040..a358cd82 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/GeertJohan/go.rice v1.0.0 github.com/glendc/go-external-ip v0.0.0-20170425150139-139229dcdddd github.com/go-playground/universal-translator v0.17.0 // indirect + github.com/go-sql-driver/mysql v1.6.0 github.com/gorilla/sessions v1.2.0 github.com/jcelliott/lumber v0.0.0-20160324203708-dd349441af25 // indirect github.com/labstack/echo-contrib v0.9.0 diff --git a/go.sum b/go.sum index 206090f4..6ebb75f8 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8c github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/handler/routes.go b/handler/routes.go index 9c18fc40..6f839236 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -148,7 +148,7 @@ func NewClient(db store.IStore) echo.HandlerFunc { } // validate the input Allocation IPs - allocatedIPs, err := util.GetAllocatedIPs("") + allocatedIPs, err := util.GetAllocatedIPs(db, "") check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, client.AllocatedIPs) if !check { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)}) @@ -268,7 +268,7 @@ func UpdateClient(db store.IStore) echo.HandlerFunc { } client := *clientData.Client // validate the input Allocation IPs - allocatedIPs, err := util.GetAllocatedIPs(client.ID) + allocatedIPs, err := util.GetAllocatedIPs(db, client.ID) check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, _client.AllocatedIPs) if !check { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)}) @@ -624,7 +624,7 @@ func SuggestIPAllocation(db store.IStore) echo.HandlerFunc { // we take the first available ip address from // each server's network addresses. suggestedIPs := make([]string, 0) - allocatedIPs, err := util.GetAllocatedIPs("") + allocatedIPs, err := util.GetAllocatedIPs(db, "") if err != nil { log.Error("Cannot suggest ip allocation. Failed to get list of allocated ip addresses: ", err) return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{ diff --git a/main.go b/main.go index 99e04607..8a44b43d 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,9 @@ import ( "github.com/ngoduykhanh/wireguard-ui/emailer" "github.com/ngoduykhanh/wireguard-ui/handler" "github.com/ngoduykhanh/wireguard-ui/router" + "github.com/ngoduykhanh/wireguard-ui/store" "github.com/ngoduykhanh/wireguard-ui/store/jsondb" + "github.com/ngoduykhanh/wireguard-ui/store/mysqldb" "github.com/ngoduykhanh/wireguard-ui/util" ) @@ -35,6 +37,13 @@ var ( flagEmailFrom string flagEmailFromName string = "WireGuard UI" flagSessionSecret string + flagDBType string = "jsondb" + flagDBHost string = "localhost" + flagDBPort int = 3306 + flagDBDatabase string = "wireguard-ui" + flagDBUsername string + flagDBPassword string + flagDBTLS string = "false" ) const ( @@ -61,6 +70,13 @@ func init() { flag.StringVar(&flagEmailFrom, "email-from", util.LookupEnvOrString("EMAIL_FROM_ADDRESS", flagEmailFrom), "'From' email address.") flag.StringVar(&flagEmailFromName, "email-from-name", util.LookupEnvOrString("EMAIL_FROM_NAME", flagEmailFromName), "'From' email name.") flag.StringVar(&flagSessionSecret, "session-secret", util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret), "The key used to encrypt session cookies.") + flag.StringVar(&flagDBType, "db-type", util.LookupEnvOrString("DB_TYPE", flagDBType), "Type of database to use. [jsondb, mysql]") + flag.StringVar(&flagDBHost, "db-host", util.LookupEnvOrString("DB_HOST", flagDBHost), "Database host") + flag.IntVar(&flagDBPort, "db-port", util.LookupEnvOrInt("DB_PORT", flagDBPort), "Database port") + flag.StringVar(&flagDBDatabase, "db-database", util.LookupEnvOrString("DB_DATABASE", flagDBDatabase), "Database name") + flag.StringVar(&flagDBUsername, "db-username", util.LookupEnvOrString("DB_USERNAME", flagDBUsername), "Database username") + flag.StringVar(&flagDBPassword, "db-password", util.LookupEnvOrString("DB_PASSWORD", flagDBPassword), "Database password") + flag.StringVar(&flagDBTLS, "db-tls", util.LookupEnvOrString("DB_TLS", flagDBTLS), "TLS mode. [true, false, skip-verify, preferred]") flag.Parse() // update runtime config @@ -76,6 +92,13 @@ func init() { util.EmailFrom = flagEmailFrom util.EmailFromName = flagEmailFromName util.SessionSecret = []byte(flagSessionSecret) + util.DBType = flagDBType + util.DBHost = flagDBHost + util.DBPort = flagDBPort + util.DBDatabase = flagDBDatabase + util.DBUsername = flagDBUsername + util.DBPassword = flagDBPassword + util.DBTLS = flagDBTLS // print app information fmt.Println("Wireguard UI") @@ -89,18 +112,12 @@ func init() { //fmt.Println("Sendgrid key\t:", util.SendgridApiKey) fmt.Println("Email from\t:", util.EmailFrom) fmt.Println("Email from name\t:", util.EmailFromName) + fmt.Println("Datastore\t:", util.DBType) //fmt.Println("Session secret\t:", util.SessionSecret) } func main() { - db, err := jsondb.New("./db") - if err != nil { - panic(err) - } - if err := db.Init(); err != nil { - panic(err) - } // set app extra data extraData := make(map[string]string) extraData["appVersion"] = appVersion @@ -111,6 +128,23 @@ func main() { // rice file server for assets. "assets" is the folder where the files come from. assetHandler := http.FileServer(rice.MustFindBox("assets").HTTPBox()) + // Configure database + var db store.IStore + var err error + switch util.DBType { + case "jsondb": + db, err = jsondb.New("./db") + case "mysql": + db, err = mysqldb.New(util.DBUsername, util.DBPassword, util.DBHost, util.DBPort, util.DBDatabase, util.DBTLS, tmplBox) + } + + if err != nil { + panic(err) + } + if err := db.Init(); err != nil { + panic(err) + } + // register routes app := router.New(tmplBox, extraData, util.SessionSecret) diff --git a/store/mysqldb/mysqldb.go b/store/mysqldb/mysqldb.go new file mode 100644 index 00000000..58290b7f --- /dev/null +++ b/store/mysqldb/mysqldb.go @@ -0,0 +1,514 @@ +// Package mysqldb provides a MySQL storage backend for Wireguard UI +package mysqldb + +import ( + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + rice "github.com/GeertJohan/go.rice" + "github.com/go-sql-driver/mysql" + "github.com/skip2/go-qrcode" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/ngoduykhanh/wireguard-ui/util" +) + +// MySQLDB - Representation of MySQL database backend +type MySQLDB struct { + conn *sql.DB + schema string + dbName string +} + +// String to split each item in array +var arrayDelimiter = "," + +// New returns pointer to MySQL database +func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) { + // Set connection config + config := mysql.NewConfig() + config.User = uname + config.Passwd = pwd + config.Net = "tcp" + config.Addr = fmt.Sprintf("%s:%d", host, port) + config.DBName = database + config.MultiStatements = true + config.ParseTime = true + config.TLSConfig = tls + + // Open connection pool + conn, err := sql.Open("mysql", config.FormatDSN()) + if err != nil { + return nil, err + } + conn.SetConnMaxLifetime(time.Minute * 3) + conn.SetMaxOpenConns(10) + conn.SetMaxIdleConns(10) + + // Test the connection + if err := conn.Ping(); err != nil { + return nil, err + } + + // Load DB schema + schema, err := templateBox.String("mysql.sql") + if err != nil { + return nil, err + } + + ans := MySQLDB{ + conn: conn, + schema: schema, + dbName: database, + } + return &ans, nil +} + +// Init initializes the database +func (o *MySQLDB) Init() error { + // Check if database is empty + var databaseEmpty int + err := o.conn.QueryRow( + "SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?", + o.dbName, + ).Scan(&databaseEmpty) + if err != nil { + return err + } + + if !(databaseEmpty > 0) { + // Initialize database + // Tell the user what we're doing as this could take a while + fmt.Println("Initializing database") + + // Create database schema + if _, err := o.conn.Exec(o.schema); err != nil { + return err + } + + // servers's interface + if _, err := o.conn.Exec( + "INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);", + util.DefaultServerAddress, + util.DefaultServerPort, + time.Now().UTC(), + ); err != nil { + return err + } + + // server's keypair + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return err + } + + if _, err := o.conn.Exec( + "INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);", + key.String(), + key.PublicKey().String(), + time.Now().UTC(), + ); err != nil { + return err + } + + // global settings + publicInterface, err := util.GetPublicIP() + if err != nil { + return err + } + + if _, err := o.conn.Exec( + "INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);", + publicInterface.IPAddress, + util.DefaultDNS, + util.DefaultMTU, + util.DefaultPersistentKeepalive, + util.DefaultConfigFilePath, + time.Now().UTC(), + ); err != nil { + return err + } + + // user info + if _, err := o.conn.Exec( + "INSERT INTO users (username, password) VALUES (?, ?);", + util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername), + util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword), + ); err != nil { + return err + } + } + + return nil +} + +// GetUser func to query user info from the database +func (o *MySQLDB) GetUser() (model.User, error) { + user := model.User{} + row := o.conn.QueryRow("SELECT username, password FROM users;") + err := row.Scan( + &user.Username, + &user.Password, + ) + return user, err +} + +// GetGlobalSettings func to query global settings from the database +func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) { + settings := model.GlobalSetting{} + var dnsServers string + + row := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;") + // Can't use ScanStruct here as doesn't know how to handle + // dns_servers list. Instead we must populate struct it manually. + err := row.Scan( + &settings.EndpointAddress, + &dnsServers, + &settings.MTU, + &settings.PersistentKeepalive, + &settings.ConfigFilePath, + &settings.UpdatedAt, + ) + settings.DNSServers = strings.Split(dnsServers, arrayDelimiter) + return settings, err +} + +// GetServer func to query Server setting from the database +func (o *MySQLDB) GetServer() (model.Server, error) { + server := model.Server{} + + // Get interface + serverInterface := model.ServerInterface{} + var addresses string + + row := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;") + err := row.Scan( + &addresses, + &serverInterface.ListenPort, + &serverInterface.UpdatedAt, + &serverInterface.PostUp, + &serverInterface.PostDown, + ) + serverInterface.Addresses = strings.Split(addresses, arrayDelimiter) + if err != nil { + return server, err + } + + // Get keypair + serverKeyPair := model.ServerKeypair{} + if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;"). + Scan( + &serverKeyPair.PrivateKey, + &serverKeyPair.PublicKey, + &serverKeyPair.UpdatedAt, + ); err != nil { + return server, err + } + + // create Server object and return + server.Interface = &serverInterface + server.KeyPair = &serverKeyPair + return server, nil +} + +// GetClients func to query Client settings from the database +func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) { + var clients []model.ClientData + + rows, err := o.conn.Query("SELECT * FROM clients;") + if err != nil { + return clients, err + } + + for rows.Next() { + client := model.Client{} + clientData := model.ClientData{} + var allocatedIPs string + var allowedIPs string + var extraAllowedIPs string + + // Get client info + if err := rows.Scan( + &client.ID, + &client.PrivateKey, + &client.PublicKey, + &client.PresharedKey, + &client.Name, + &client.Email, + &allocatedIPs, + &allowedIPs, + &extraAllowedIPs, + &client.UseServerDNS, + &client.Enabled, + &client.CreatedAt, + &client.UpdatedAt, + ); err != nil { + return clients, err + } + client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter) + client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter) + client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter) + + // generate client qrcode image in base64 + if hasQRCode && client.PrivateKey != "" { + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() + + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + if err == nil { + clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) + } else { + fmt.Print("Cannot generate QR code: ", err) + } + } + + // create the list of clients and their qrcode data + clientData.Client = &client + clients = append(clients, clientData) + } + + return clients, nil +} + +// GetClientByID func to query Clients by ID from the database +func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) { + client := model.Client{} + clientData := model.ClientData{} + var allocatedIPs string + var allowedIPs string + var extraAllowedIPs string + + // read client info + if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan( + &client.ID, + &client.PrivateKey, + &client.PublicKey, + &client.PresharedKey, + &client.Name, + &client.Email, + &allocatedIPs, + &allowedIPs, + &extraAllowedIPs, + &client.UseServerDNS, + &client.Enabled, + &client.CreatedAt, + &client.UpdatedAt, + ); err != nil { + return clientData, err + } + client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter) + client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter) + client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter) + + // generate client qrcode image in base64 + if hasQRCode && client.PrivateKey != "" { + server, _ := o.GetServer() + globalSettings, _ := o.GetGlobalSettings() + + png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) + if err == nil { + clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) + } else { + fmt.Print("Cannot generate QR code: ", err) + } + } + + clientData.Client = &client + + return clientData, nil +} + +// SaveClient func saves client to database +func (o *MySQLDB) SaveClient(client model.Client) error { + // If client doesn't exist, create a record, else update existing record + querySet := ` + SET + @id = ?, + @private_key = ?, + @public_key = ?, + @preshared_key = ?, + @name = ?, + @email = ?, + @allocated_ips = ?, + @allowed_ips = ?, + @extra_allowed_ips = ?, + @use_server_dns = ?, + @enabled = ?, + @created_at = ?, + @updated_at = ?;` + queryInsert := ` + INSERT INTO clients( + id, + private_key, + public_key, + preshared_key, + NAME, + email, + allocated_ips, + allowed_ips, + extra_allowed_ips, + use_server_dns, + enabled, + created_at, + updated_at + ) + VALUES( + @id, + @private_key, + @public_key, + @preshared_key, + @name, + @email, + @allocated_ips, + @allowed_ips, + @extra_allowed_ips, + @use_server_dns, + @enabled, + @created_at, + @updated_at + ) + ON DUPLICATE KEY + UPDATE + id = @id, + private_key = @private_key, + public_key = @public_key, + preshared_key = @preshared_key, + NAME = @name, + email = @email, + allocated_ips = @allocated_ips, + allowed_ips = @allowed_ips, + extra_allowed_ips = @extra_allowed_ips, + use_server_dns = @use_server_dns, + enabled = @enabled, + created_at = @created_at, + updated_at = @updated_at;` + + tx, err := o.conn.Begin() + if err != nil { + return err + } + // set values + if _, err := tx.Exec( + querySet, + client.ID, + client.PrivateKey, + client.PublicKey, + client.PresharedKey, + client.Name, + client.Email, + strings.Join(client.AllocatedIPs, arrayDelimiter), + strings.Join(client.AllowedIPs, arrayDelimiter), + strings.Join(client.ExtraAllowedIPs, arrayDelimiter), + client.UseServerDNS, + client.Enabled, + client.CreatedAt, + client.UpdatedAt, + ); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) + } + + return err + } + + // insert or update row + if _, err := tx.Exec(queryInsert); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr) + } + + return err + } + + return tx.Commit() +} + +// DeleteClient func deletes client from the database +func (o *MySQLDB) DeleteClient(clientID string) error { + if _, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID); err != nil { + return err + } + + return nil +} + +// SaveServerInterface func saves a server interface to database +func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error { + // No need for ON DUPLICATE KEY UPDATE as only ever 1 record + query := ` + UPDATE + interfaces + SET + addresses = ?, + listen_port = ?, + updated_at = ?, + post_up = ?, + post_down = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + strings.Join(serverInterface.Addresses, arrayDelimiter), + serverInterface.ListenPort, + serverInterface.UpdatedAt, + serverInterface.PostUp, + serverInterface.PostDown, + ) + + return err +} + +// SaveServerKeyPair func saves a server keypair to database +func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error { + query := ` + UPDATE + keypair + SET + private_key = ?, + public_key = ?, + updated_at = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + serverKeyPair.PrivateKey, + serverKeyPair.PublicKey, + serverKeyPair.UpdatedAt, + ) + + return err +} + +// SaveGlobalSettings saves global settings to database +func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error { + query := ` + UPDATE + global_settings + SET + endpoint_address = ?, + dns_servers = ?, + mtu = ?, + persistent_keepalive = ?, + config_file_path = ?, + updated_at = ? + WHERE + id = 1;` + + _, err := o.conn.Exec( + query, + globalSettings.EndpointAddress, + strings.Join(globalSettings.DNSServers, arrayDelimiter), + globalSettings.MTU, + globalSettings.PersistentKeepalive, + globalSettings.ConfigFilePath, + globalSettings.UpdatedAt, + ) + + return err +} diff --git a/templates/mysql.sql b/templates/mysql.sql new file mode 100644 index 00000000..ffeb26f7 --- /dev/null +++ b/templates/mysql.sql @@ -0,0 +1,80 @@ +START TRANSACTION; + +CREATE TABLE `clients` ( + `id` VARCHAR(255) NOT NULL, + `private_key` VARCHAR(255) NOT NULL, + `public_key` VARCHAR(255) NOT NULL, + `preshared_key` VARCHAR(255) NOT NULL, + `name` VARCHAR(255) NOT NULL, + `email` VARCHAR(255), + `allocated_ips` VARCHAR(2550) NOT NULL, + `allowed_ips` VARCHAR(2550) NOT NULL, + `extra_allowed_ips` VARCHAR(2550), + `use_server_dns` TINYINT(1) NOT NULL, + `enabled` TINYINT(1) NOT NULL, + `created_at` DATETIME NOT NULL, + `updated_at` DATETIME NOT NULL +); + +CREATE TABLE `global_settings` ( + `id` INT(11) NOT NULL, + `endpoint_address` VARCHAR(255) NOT NULL, + `dns_servers` VARCHAR(2550) NOT NULL, + `mtu` VARCHAR(255) NOT NULL, + `persistent_keepalive` VARCHAR(255) NOT NULL, + `config_file_path` VARCHAR(255) NOT NULL, + `updated_at` DATETIME NOT NULL +); + +CREATE TABLE `interfaces` ( + `id` INT(11) NOT NULL, + `addresses` VARCHAR(2550) NOT NULL, + `listen_port` VARCHAR(5) NOT NULL, + `updated_at` DATETIME NOT NULL, + `post_up` VARCHAR(255) DEFAULT "", + `post_down` VARCHAR(255) DEFAULT "" +); + +CREATE TABLE `keypair` ( + `id` INT(11) NOT NULL, + `private_key` VARCHAR(255) NOT NULL, + `public_key` VARCHAR(255) NOT NULL, + `updated_at` DATETIME NOT NULL +); + +CREATE TABLE `users` ( + `id` INT(11) NOT NULL, + `username` VARCHAR(255) NOT NULL, + `password` VARCHAR(255) NOT NULL +); + + +ALTER TABLE `clients` + ADD PRIMARY KEY (`id`); + +ALTER TABLE `global_settings` + ADD PRIMARY KEY (`id`); + +ALTER TABLE `interfaces` + ADD PRIMARY KEY (`id`); + +ALTER TABLE `keypair` + ADD PRIMARY KEY (`id`); + +ALTER TABLE `users` + ADD PRIMARY KEY (`id`); + + +ALTER TABLE `global_settings` + MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; + +ALTER TABLE `interfaces` + MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; + +ALTER TABLE `keypair` + MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; + +ALTER TABLE `users` + MODIFY `id` INT(11) NOT NULL AUTO_INCREMENT; + +COMMIT; diff --git a/util/config.go b/util/config.go index 80cbc9c2..9dc4948e 100644 --- a/util/config.go +++ b/util/config.go @@ -16,6 +16,13 @@ var ( EmailSubject string EmailContent string SessionSecret []byte + DBType string + DBHost string + DBPort int + DBDatabase string + DBUsername string + DBPassword string + DBTLS string ) const ( diff --git a/util/util.go b/util/util.go index 7c347a96..0146f984 100644 --- a/util/util.go +++ b/util/util.go @@ -1,7 +1,6 @@ package util import ( - "encoding/json" "errors" "fmt" "net" @@ -15,7 +14,7 @@ import ( externalip "github.com/glendc/go-external-ip" "github.com/labstack/gommon/log" "github.com/ngoduykhanh/wireguard-ui/model" - "github.com/sdomino/scribble" + "github.com/ngoduykhanh/wireguard-ui/store" ) // BuildClientConfig to create wireguard client config string @@ -214,20 +213,15 @@ func GetIPFromCIDR(cidr string) (string, error) { } // GetAllocatedIPs to get all ip addresses allocated to clients and server -func GetAllocatedIPs(ignoreClientID string) ([]string, error) { +func GetAllocatedIPs(db store.IStore, ignoreClientID string) ([]string, error) { allocatedIPs := make([]string, 0) - - // initialize database directory - dir := "./db" - db, err := scribble.New(dir, nil) - if err != nil { - return nil, err - } - + // read server information serverInterface := model.ServerInterface{} - if err := db.Read("server", "interfaces", &serverInterface); err != nil { + if server, err := db.GetServer(); err != nil { return nil, err + } else{ + serverInterface = *server.Interface } // append server's addresses to the result @@ -240,7 +234,7 @@ func GetAllocatedIPs(ignoreClientID string) ([]string, error) { } // read client information - records, err := db.ReadAll("clients") + records, err := db.GetClients(false) if err != nil { return nil, err } @@ -248,9 +242,7 @@ func GetAllocatedIPs(ignoreClientID string) ([]string, error) { // append client's addresses to the result for _, f := range records { client := model.Client{} - if err := json.Unmarshal([]byte(f), &client); err != nil { - return nil, err - } + client = *f.Client if client.ID != ignoreClientID { for _, cidr := range client.AllocatedIPs {