Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Varied bugfixes #934

Merged
merged 10 commits into from
Jun 12, 2024
20 changes: 17 additions & 3 deletions client/incus.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ func (r *ProtocolIncus) DoHTTP(req *http.Request) (*http.Response, error) {
return r.http.Do(req)
}

// DoWebsocket performs a websocket connection, using OIDC authentication if set.
func (r *ProtocolIncus) DoWebsocket(dialer websocket.Dialer, uri string, req *http.Request) (*websocket.Conn, *http.Response, error) {
r.addClientHeaders(req)

if r.oidcClient != nil {
return r.oidcClient.dial(dialer, uri, req)
}

return dialer.Dial(uri, req.Header)
}

// addClientHeaders sets headers from client settings.
// User-Agent (if r.httpUserAgent is set).
// X-Incus-authenticated (if r.requireAuthenticated is set).
Expand Down Expand Up @@ -245,11 +256,13 @@ func (r *ProtocolIncus) rawQuery(method string, url string, data any, ETag strin
switch data := data.(type) {
case io.Reader:
// Some data to be sent along with the request
req, err = http.NewRequestWithContext(r.ctx, method, url, data)
req, err = http.NewRequestWithContext(r.ctx, method, url, io.NopCloser(data))
if err != nil {
return nil, "", err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(data), nil }

// Set the encoding accordingly
req.Header.Set("Content-Type", "application/octet-stream")
default:
Expand All @@ -267,6 +280,8 @@ func (r *ProtocolIncus) rawQuery(method string, url string, data any, ETag strin
return nil, "", err
}

req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(buf.Bytes())), nil }

// Set the encoding accordingly
req.Header.Set("Content-Type", "application/json")

Expand Down Expand Up @@ -432,10 +447,9 @@ func (r *ProtocolIncus) rawWebsocket(url string) (*websocket.Conn, error) {
// Create temporary http.Request using the http url, not the ws one, so that we can add the client headers
// for the websocket request.
req := &http.Request{URL: &r.httpBaseURL, Header: http.Header{}}
r.addClientHeaders(req)

// Establish the connection
conn, resp, err := dialer.Dial(url, req.Header)
conn, resp, err := r.DoWebsocket(dialer, url, req)
if err != nil {
if resp != nil {
_, _, err = incusParseResponse(resp)
Expand Down
49 changes: 49 additions & 0 deletions client/incus_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"syscall"
"time"

"github.com/gorilla/websocket"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
Expand Down Expand Up @@ -114,6 +115,10 @@ func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
clientID := resp.Header.Get("X-Incus-OIDC-clientid")
audience := resp.Header.Get("X-Incus-OIDC-audience")

if issuer == "" || clientID == "" {
return resp, nil
}

err = o.refresh(issuer, clientID)
if err != nil {
err = o.authenticate(issuer, clientID, audience)
Expand All @@ -125,6 +130,16 @@ func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
// Set the new access token in the header.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.tokens.AccessToken))

// Reset the request body.
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return nil, err
}

req.Body = body
}

resp, err = o.httpClient.Do(req)
if err != nil {
return nil, err
Expand All @@ -133,6 +148,40 @@ func (o *oidcClient) do(req *http.Request) (*http.Response, error) {
return resp, nil
}

// dial function executes a websocket request and handles OIDC authentication and refresh.
func (o *oidcClient) dial(dialer websocket.Dialer, uri string, req *http.Request) (*websocket.Conn, *http.Response, error) {
conn, resp, err := dialer.Dial(uri, req.Header)
if err != nil && resp == nil {
hallyn marked this conversation as resolved.
Show resolved Hide resolved
return nil, nil, err
}

// Return immediately if the error is not HTTP status unauthorized.
if conn != nil && resp.StatusCode != http.StatusUnauthorized {
return conn, resp, nil
}

issuer := resp.Header.Get("X-Incus-OIDC-issuer")
clientID := resp.Header.Get("X-Incus-OIDC-clientid")
audience := resp.Header.Get("X-Incus-OIDC-audience")

if issuer == "" || clientID == "" {
return nil, resp, err
}

err = o.refresh(issuer, clientID)
if err != nil {
err = o.authenticate(issuer, clientID, audience)
if err != nil {
return nil, resp, err
}
}

// Set the new access token in the header.
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.tokens.AccessToken))

return dialer.Dial(uri, req.Header)
}

// getProvider initializes a new OpenID Connect Relying Party for a given issuer and clientID.
// The function also creates a secure CookieHandler with random encryption and hash keys, and applies a series of configurations on the Relying Party.
func (o *oidcClient) getProvider(issuer string, clientID string) (rp.RelyingParty, error) {
Expand Down
5 changes: 5 additions & 0 deletions cmd/incus-agent/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func (r *eventsServe) String() string {
return "event handler"
}

// Code returns the HTTP code.
func (r *eventsServe) Code() int {
return http.StatusOK
}

func eventsSocket(d *Daemon, r *http.Request, w http.ResponseWriter) error {
typeStr := r.FormValue("type")
if typeStr == "" {
Expand Down
5 changes: 5 additions & 0 deletions cmd/incus-agent/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func (r *sftpServe) String() string {
return "sftp handler"
}

// Code returns the HTTP code.
func (r *sftpServe) Code() int {
return http.StatusOK
}

func (r *sftpServe) Render(w http.ResponseWriter) error {
// Upgrade to sftp.
if r.r.Header.Get("Upgrade") != "sftp" {
Expand Down
28 changes: 26 additions & 2 deletions cmd/incus/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ func (c *cmdNetworkCreate) Command() *cobra.Command {
cmd.Example = cli.FormatSection("", i18n.G(`incus network create foo
Create a new network called foo

incus network create foo < config.yaml
Create a new network called foo using the content of config.yaml.

incus network create bar network=baz --type ovn
Create a new OVN network called bar using baz as its uplink network`))

Expand All @@ -358,12 +361,27 @@ incus network create bar network=baz --type ovn
}

func (c *cmdNetworkCreate) Run(cmd *cobra.Command, args []string) error {
var stdinData api.NetworkPut

// Quick checks.
exit, err := c.global.CheckArgs(cmd, args, 1, -1)
if exit {
return err
}

// If stdin isn't a terminal, read text from it
if !termios.IsTerminal(getStdinFd()) {
contents, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}

err = yaml.Unmarshal(contents, &stdinData)
if err != nil {
return err
}
}

// Parse remote
resources, err := c.global.ParseServers(args[0])
if err != nil {
Expand All @@ -374,11 +392,17 @@ func (c *cmdNetworkCreate) Run(cmd *cobra.Command, args []string) error {
client := resource.server

// Create the network
network := api.NetworksPost{}
network := api.NetworksPost{
NetworkPut: stdinData,
}

network.Name = resource.name
network.Config = map[string]string{}
network.Type = c.network.flagType

if network.Config == nil {
network.Config = map[string]string{}
}

for i := 1; i < len(args); i++ {
entry := strings.SplitN(args[i], "=", 2)
if len(entry) < 2 {
Expand Down
5 changes: 5 additions & 0 deletions cmd/incusd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,11 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) {
resp = response.NotFound(fmt.Errorf("Method %q not found", r.Method))
}

// If sending out Forbidden, make sure we have OIDC headers.
if resp.Code() == http.StatusForbidden && d.oidcVerifier != nil {
_ = d.oidcVerifier.WriteHeaders(w)
}

// Handle errors
err = resp.Render(w)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions cmd/incusd/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ func (r *eventsServe) String() string {
return "event handler"
}

// Code returns the HTTP code.
func (r *eventsServe) Code() int {
return http.StatusOK
}

func eventsSocket(s *state.State, r *http.Request, w http.ResponseWriter) error {
// Detect project mode.
projectName := request.QueryParam(r, "project")
Expand Down
5 changes: 5 additions & 0 deletions cmd/incusd/instance_sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ func (r *sftpServeResponse) String() string {
return "sftp handler"
}

// Code returns the HTTP code.
func (r *sftpServeResponse) Code() int {
return http.StatusOK
}

func (r *sftpServeResponse) Render(w http.ResponseWriter) error {
defer func() { _ = r.instConn.Close() }()

Expand Down
17 changes: 15 additions & 2 deletions cmd/incusd/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func networksPost(d *Daemon, r *http.Request) response.Response {
}

err = s.DB.Cluster.Transaction(r.Context(), func(ctx context.Context, tx *db.ClusterTx) error {
return tx.CreatePendingNetwork(ctx, targetNode, projectName, req.Name, netType.DBType(), req.Config)
return tx.CreatePendingNetwork(ctx, targetNode, projectName, req.Name, req.Description, netType.DBType(), req.Config)
})
if err != nil {
if err == db.ErrAlreadyDefined {
Expand Down Expand Up @@ -471,7 +471,7 @@ func networksPost(d *Daemon, r *http.Request) response.Response {
for _, member := range members {
// Don't pass in any config, as these nodes don't have any node-specific
// config and we don't want to create duplicate global config.
err = tx.CreatePendingNetwork(ctx, member.Name, projectName, req.Name, netType.DBType(), nil)
err = tx.CreatePendingNetwork(ctx, member.Name, projectName, req.Name, req.Description, netType.DBType(), nil)
if err != nil && !errors.Is(err, db.ErrAlreadyDefined) {
return fmt.Errorf("Failed creating pending network for member %q: %w", member.Name, err)
}
Expand All @@ -489,6 +489,19 @@ func networksPost(d *Daemon, r *http.Request) response.Response {
return response.SmartError(err)
}

n, err := network.LoadByName(s, projectName, req.Name)
if err != nil {
return response.SmartError(fmt.Errorf("Failed loading network: %w", err))
}

err = s.Authorizer.AddNetwork(r.Context(), projectName, req.Name)
if err != nil {
logger.Error("Failed to add network to authorizer", logger.Ctx{"name": req.Name, "project": projectName, "error": err})
hallyn marked this conversation as resolved.
Show resolved Hide resolved
}

requestor := request.CreateRequestor(r)
s.Events.SendLifecycle(projectName, lifecycle.NetworkCreated.Event(n, requestor, nil))

return resp
}

Expand Down
5 changes: 5 additions & 0 deletions cmd/incusd/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,11 @@ func (r *operationWebSocket) String() string {
return md.ID
}

// Code returns the HTTP code.
func (r *operationWebSocket) Code() int {
return http.StatusOK
}

// swagger:operation GET /1.0/operations/{id}/websocket?public operations operation_websocket_get_untrusted
//
// Get the websocket stream
Expand Down
4 changes: 2 additions & 2 deletions internal/server/db/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ WHERE networks.id = ? AND networks.state = ?
}

// CreatePendingNetwork creates a new pending network on the node with the given name.
func (c *ClusterTx) CreatePendingNetwork(ctx context.Context, node string, projectName string, name string, netType NetworkType, conf map[string]string) error {
func (c *ClusterTx) CreatePendingNetwork(ctx context.Context, node string, projectName string, name string, description string, netType NetworkType, conf map[string]string) error {
// First check if a network with the given name exists, and, if so, that it's in the pending state.
network := struct {
id int64
Expand Down Expand Up @@ -325,7 +325,7 @@ func (c *ClusterTx) CreatePendingNetwork(ctx context.Context, node string, proje

// No existing network with the given name was found, let's create one.
columns := []string{"project_id", "name", "type", "description"}
values := []any{projectID, name, netType, ""}
values := []any{projectID, name, netType, description}
networkID, err = query.UpsertObject(c.tx, "networks", columns, values)
if err != nil {
return err
Expand Down
12 changes: 6 additions & 6 deletions internal/server/db/networks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,23 @@ func TestCreatePendingNetwork(t *testing.T) {
require.NoError(t, err)

config := map[string]string{"bridge.external_interfaces": "foo"}
err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, config)
err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, config)
require.NoError(t, err)

networkID, err := tx.GetNetworkID(context.Background(), api.ProjectDefaultName, "network1")
require.NoError(t, err)
assert.True(t, networkID > 0)

config = map[string]string{"bridge.external_interfaces": "bar"}
err = tx.CreatePendingNetwork(context.Background(), "rusp", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, config)
err = tx.CreatePendingNetwork(context.Background(), "rusp", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, config)
require.NoError(t, err)

// The initial node (whose name is 'none' by default) is missing.
_, err = tx.NetworkNodeConfigs(context.Background(), networkID)
require.EqualError(t, err, "Network not defined on nodes: none")

config = map[string]string{"bridge.external_interfaces": "egg,if1/eth0/1001"}
err = tx.CreatePendingNetwork(context.Background(), "none", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, config)
err = tx.CreatePendingNetwork(context.Background(), "none", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, config)
require.NoError(t, err)

// Now the storage is defined on all nodes.
Expand All @@ -90,10 +90,10 @@ func TestNetworksCreatePending_AlreadyDefined(t *testing.T) {
_, err := tx.CreateNode("buzz", "1.2.3.4:666")
require.NoError(t, err)

err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, map[string]string{})
err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, map[string]string{})
require.NoError(t, err)

err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, map[string]string{})
err = tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, map[string]string{})
require.Equal(t, db.ErrAlreadyDefined, err)
}

Expand All @@ -102,6 +102,6 @@ func TestNetworksCreatePending_NonExistingNode(t *testing.T) {
tx, cleanup := db.NewTestClusterTx(t)
defer cleanup()

err := tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", db.NetworkTypeBridge, map[string]string{})
err := tx.CreatePendingNetwork(context.Background(), "buzz", api.ProjectDefaultName, "network1", "", db.NetworkTypeBridge, map[string]string{})
require.True(t, response.IsNotFoundError(err))
}
Loading
Loading