Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/export api client #16

Merged
merged 5 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions internal/api/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package api

import (
"fmt"
"net"
"time"

"github.com/T-Systems-MMS/oc-daemon/internal/ocrunner"
"github.com/T-Systems-MMS/oc-daemon/internal/vpnstatus"
log "github.com/sirupsen/logrus"
"github.com/T-Systems-MMS/oc-daemon/pkg/vpnstatus"
)

const (
Expand All @@ -24,11 +24,11 @@ type Client struct {
}

// Request sends msg to the server and returns the server's response
func (c *Client) Request(msg *Message) *Message {
func (c *Client) Request(msg *Message) (*Message, error) {
// connect to daemon
conn, err := net.DialTimeout("unix", c.sockFile, connectTimeout)
if err != nil {
log.WithError(err).Fatal("Client dial error")
return nil, fmt.Errorf("client dial error: %w", err)
}
defer func() {
_ = conn.Close()
Expand All @@ -37,72 +37,86 @@ func (c *Client) Request(msg *Message) *Message {
// set timeout for entire request/response message exchange
deadline := time.Now().Add(clientTimeout)
if err := conn.SetDeadline(deadline); err != nil {
log.WithError(err).Fatal("Client set deadline error")
return nil, fmt.Errorf("client set deadline error: %w", err)
}

// send message to daemon
err = WriteMessage(conn, msg)
if err != nil {
log.WithError(err).Fatal("Client send message error")
return nil, fmt.Errorf("client send message error: %w", err)
}

// receive reply
reply, err := ReadMessage(conn)
if err != nil {
log.WithError(err).Fatal("Client receive message error")
return nil, fmt.Errorf("client receive message error: %w", err)
}

return reply
return reply, nil
}

// Query retrieves the VPN status from the daemon
func (c *Client) Query() *vpnstatus.Status {
// send query to daemon
func (c *Client) Query() (*vpnstatus.Status, error) {
msg := NewMessage(TypeVPNQuery, nil)
reply := c.Request(msg)
// send query to daemon

// handle response
reply, err := c.Request(msg)
if err != nil {
return nil, err
}
switch reply.Type {
case TypeOK:
// parse status in reply
status, err := vpnstatus.NewFromJSON(reply.Value)
if err != nil {
log.WithError(err).Fatal("Client received invalid status")
return nil, fmt.Errorf("client received invalid status: %w", err)
}
return status
return status, nil

case TypeError:
log.WithField("error", string(reply.Value)).Error("Client received error reply")
err := fmt.Errorf("%s", reply.Value)
return nil, fmt.Errorf("client received error reply: %w", err)
}
return nil
return nil, fmt.Errorf("client received invalid reply")
}

// Connect sends a connect request with login info to the daemon
func (c *Client) Connect(login *ocrunner.LoginInfo) {
func (c *Client) Connect(login *ocrunner.LoginInfo) error {
// convert login to json
b, err := login.JSON()
if err != nil {
log.WithError(err).Fatal("Client could not convert login info to JSON")
return fmt.Errorf("client could not convert login info to JSON: %w", err)
}

// create connect request
msg := NewMessage(TypeVPNConnect, b)

// send request to server
reply := c.Request(msg)
reply, err := c.Request(msg)
if err != nil {
return err
}
if reply.Type == TypeError {
log.WithField("error", string(reply.Value)).Error("Client received error reply")
err := fmt.Errorf("%s", reply.Value)
return fmt.Errorf("client received error reply: %w", err)
}
return nil
}

// Disconnect sends a disconnect request to the daemon
func (c *Client) Disconnect() {
func (c *Client) Disconnect() error {
// send disconnect request
msg := NewMessage(TypeVPNDisconnect, nil)
reply := c.Request(msg)
reply, err := c.Request(msg)
if err != nil {
return err
}
if reply.Type == TypeError {
log.WithField("error", string(reply.Value)).Error("Client received error reply")
err := fmt.Errorf("%s", reply.Value)
return fmt.Errorf("client received error reply: %w", err)
}
return nil
}

// NewClient returns a new Client
Expand Down
6 changes: 3 additions & 3 deletions internal/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"testing"

"github.com/T-Systems-MMS/oc-daemon/internal/ocrunner"
"github.com/T-Systems-MMS/oc-daemon/internal/vpnstatus"
"github.com/T-Systems-MMS/oc-daemon/pkg/vpnstatus"
)

// initTestClientServer returns a client an server for testing;
Expand All @@ -27,7 +27,7 @@ func initTestClientServer() (*Client, *Server) {
func TestClientRequest(t *testing.T) {
client, server := initTestClientServer()
server.Start()
reply := client.Request(NewMessage(TypeVPNQuery, nil))
reply, _ := client.Request(NewMessage(TypeVPNQuery, nil))
server.Stop()

log.Println(reply)
Expand All @@ -53,7 +53,7 @@ func TestClientQuery(t *testing.T) {
}()
server.Start()
want := status
got := client.Query()
got, _ := client.Query()
server.Stop()

log.Println(got)
Expand Down
123 changes: 38 additions & 85 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,11 @@ import (
"os"
"time"

"github.com/T-Systems-MMS/oc-daemon/internal/api"
"github.com/T-Systems-MMS/oc-daemon/internal/ocrunner"
"github.com/T-Systems-MMS/oc-daemon/pkg/client"
log "github.com/sirupsen/logrus"
)

const (
// runDir is the daemons run dir
runDir = "/run/oc-daemon"

// daemon socket file
sockFile = runDir + "/daemon.sock"

// oc runner settings
vpncScript = "/usr/bin/oc-daemon-vpncscript"

// maxReconnectTries is the maximum amount or reconnect retries
maxReconnectTries = 5
)
Expand All @@ -33,29 +23,26 @@ func readXMLProfile() []byte {
return b
}

// authenticateVPN authenticates user for vpn connection and returns login info
func authenticateVPN() *ocrunner.LoginInfo {
// authenticate
auth := ocrunner.NewAuthenticate()
auth.Certificate = config.ClientCertificate
auth.Key = config.ClientKey
auth.CA = config.CACertificate
auth.XMLProfile = xmlProfile
auth.Script = vpncScript
auth.Server = config.VPNServer
auth.User = config.User
auth.Authenticate()

return &auth.Login
}
// connectVPN connects to the VPN if necessary
func connectVPN() {
// create client
c := client.NewClient()

// authenticateConnectVPN authenticates the user and connects to the VPN
func authenticateConnectVPN(client *api.Client) {
// try to read current xml profile
pre := readXMLProfile()

// autenticate user for vpn connection
login := authenticateVPN()
// authenticate
c.ClientCertificate = config.ClientCertificate
c.ClientKey = config.ClientKey
c.CACertificate = config.CACertificate
c.XMLProfile = xmlProfile
c.VPNServer = config.VPNServer
c.User = config.User
c.Password = config.Password

if err := c.Authenticate(); err != nil {
log.WithError(err).Fatal("error authenticating user for VPN")
}

// warn user if profile changed
post := readXMLProfile()
Expand All @@ -67,88 +54,54 @@ func authenticateConnectVPN(client *api.Client) {
time.Sleep(2 * time.Second)
}

// send login info to daemon
client.Connect(login)
}

// connectVPN connects to the VPN if necessary
func connectVPN() {
// create client
client := api.NewClient(sockFile)

// get status
status := client.Query()
if status == nil {
return
// connect
if err := c.Connect(); err != nil {
log.WithError(err).Fatal("error connecting to VPN")
}

// check if we need to start the VPN connection
if status.TrustedNetwork {
log.Println("Trusted network detected, nothing to do")
return
}
if status.Connected {
log.Println("VPN already connected, nothing to do")
return
}
if status.Running {
log.Println("OpenConnect client already running, nothing to do")
return
}

// authenticate and connect
authenticateConnectVPN(client)
}

// disconnectVPN disconnects the VPN
func disconnectVPN() {
// create client
client := api.NewClient(sockFile)

// check status
status := client.Query()
if status == nil {
return
}
if !status.Running {
log.Println("OpenConnect client is not running, nothing to do")
return
}
c := client.NewClient()

// disconnect
client.Disconnect()
err := c.Disconnect()
if err != nil {
log.WithError(err).Fatal("error disconnecting from VPN")
}
}

// reconnectVPN reconnects to the VPN
func reconnectVPN() {
// create client
client := api.NewClient(sockFile)
client := client.NewClient()

// check status
status := client.Query()
if status == nil {
log.Fatal("error reconnecting to VPN")
status, err := client.Query()
if err != nil {
log.WithError(err).Fatal("error reconnecting to VPN")
}

// disconnect if needed
if status.Running {
// send disconnect request
client.Disconnect()
disconnectVPN()
}

// wait for status to switch to untrusted network and not running
try := 0
for {
status := client.Query()
if status == nil {
log.Fatal("error reconnecting to VPN")
status, err := client.Query()
if err != nil {
log.WithError(err).Fatal("error reconnecting to VPN")
}

if !status.TrustedNetwork &&
!status.Connected &&
!status.Running {
// authenticate and connect
authenticateConnectVPN(client)
connectVPN()
return
}

Expand All @@ -166,10 +119,10 @@ func reconnectVPN() {

// getStatus gets the VPN status from the daemon
func getStatus() {
client := api.NewClient(sockFile)
status := client.Query()
if status == nil {
return
c := client.NewClient()
status, err := c.Query()
if err != nil {
log.Fatal(err)
}
log.Printf("Trusted Network: %t", status.TrustedNetwork)
log.Printf("Running: %t", status.Running)
Expand Down
20 changes: 11 additions & 9 deletions internal/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@ type ClientConfig struct {
CACertificate string
VPNServer string
User string
Password string
}

// empty returns if the config is empty
func (o *ClientConfig) empty() bool {
if o == nil {
func (c *ClientConfig) empty() bool {
if c == nil {
return true
}

if o.ClientCertificate == "" &&
o.ClientKey == "" &&
o.CACertificate == "" &&
o.VPNServer == "" &&
o.User == "" {
if c.ClientCertificate == "" &&
c.ClientKey == "" &&
c.CACertificate == "" &&
c.VPNServer == "" &&
c.User == "" &&
c.Password == "" {
// empty
return true
}
Expand All @@ -33,8 +35,8 @@ func (o *ClientConfig) empty() bool {
}

// save saves the config to file
func (o *ClientConfig) save(file string) {
b, err := json.MarshalIndent(o, "", " ")
func (c *ClientConfig) save(file string) {
b, err := json.MarshalIndent(c, "", " ")
if err != nil {
return
}
Expand Down
Loading