Skip to content

Commit

Permalink
feat(ws): add auth to backend (#202)
Browse files Browse the repository at this point in the history
* feat(ws): add auth to backend

Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com>

* add `DISABLE_AUTH` for interim testing (enabled by default)

Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com>

---------

Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com>
  • Loading branch information
thesuperzapper authored Feb 11, 2025
1 parent 4cbc26e commit bc6f311
Show file tree
Hide file tree
Showing 18 changed files with 725 additions and 134 deletions.
12 changes: 10 additions & 2 deletions workspaces/backend/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (

"github.com/julienschmidt/httprouter"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authorization/authorizer"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeflow/notebooks/workspaces/backend/internal/config"
Expand Down Expand Up @@ -52,20 +54,26 @@ const (
)

type App struct {
Config config.EnvConfig
Config *config.EnvConfig
logger *slog.Logger
repositories *repositories.Repositories
Scheme *runtime.Scheme
RequestAuthN authenticator.Request
RequestAuthZ authorizer.Authorizer
}

// NewApp creates a new instance of the app
func NewApp(cfg config.EnvConfig, logger *slog.Logger, cl client.Client, scheme *runtime.Scheme) (*App, error) {
func NewApp(cfg *config.EnvConfig, logger *slog.Logger, cl client.Client, scheme *runtime.Scheme, reqAuthN authenticator.Request, reqAuthZ authorizer.Authorizer) (*App, error) {

// TODO: log the configuration on startup

app := &App{
Config: cfg,
logger: logger,
repositories: repositories.NewRepositories(cl),
Scheme: scheme,
RequestAuthN: reqAuthN,
RequestAuthZ: reqAuthZ,
}
return app, nil
}
Expand Down
73 changes: 73 additions & 0 deletions workspaces/backend/api/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
Copyright 2024.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package api

import (
"fmt"
"net/http"

"k8s.io/apiserver/pkg/authorization/authorizer"

"github.com/kubeflow/notebooks/workspaces/backend/internal/auth"
)

// requireAuth verifies that the request is authenticated and authorized to take the actions specified by the given policies.
// If this method returns false, the request has been handled and the caller should return immediately.
// If this method returns true, the request is authenticated and authorized to proceed.
// This method should only be called once per request.
func (a *App) requireAuth(w http.ResponseWriter, r *http.Request, policies []*auth.ResourcePolicy) bool {
ctx := r.Context()

// if auth is disabled, allow the request to proceed
if a.Config.DisableAuth {
return true
}

// authenticate the request (extract user and groups from the request headers)
res, ok, err := a.RequestAuthN.AuthenticateRequest(r)
if err != nil {
err = fmt.Errorf("failed to authenticate request: %w", err)
a.serverErrorResponse(w, r, err)
return false
}
if !ok {
a.unauthorizedResponse(w, r)
return false
}

// for each policy, check if the user is authorized to take the requested action
for _, policy := range policies {
attributes := policy.AttributesFor(res.User)
authorized, reason, err := a.RequestAuthZ.Authorize(ctx, attributes)
if err != nil {
err = fmt.Errorf("failed to authorize request for user %q: %w", res.User.GetName(), err)
a.serverErrorResponse(w, r, err)
return false
}

if authorized != authorizer.DecisionAllow {
msg := fmt.Sprintf("authorization was denied for user %q", res.User.GetName())
if reason != "" {
msg = fmt.Sprintf("%s: %s", msg, reason)
}
a.forbiddenResponse(w, r, msg)
return false
}
}

return true
}
33 changes: 33 additions & 0 deletions workspaces/backend/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func (a *App) LogError(r *http.Request, err error) {
a.logger.Error(err.Error(), "method", method, "uri", uri)
}

func (a *App) LogWarn(r *http.Request, message string) {
var (
method = r.Method
uri = r.URL.RequestURI()
)

a.logger.Warn(message, "method", method, "uri", uri)
}

//nolint:unused
func (a *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
httpError := &HTTPError{
Expand Down Expand Up @@ -103,6 +112,30 @@ func (a *App) methodNotAllowedResponse(w http.ResponseWriter, r *http.Request) {
a.errorResponse(w, r, httpError)
}

func (a *App) unauthorizedResponse(w http.ResponseWriter, r *http.Request) {
httpError := &HTTPError{
StatusCode: http.StatusUnauthorized,
ErrorResponse: ErrorResponse{
Code: strconv.Itoa(http.StatusUnauthorized),
Message: "authentication is required to access this resource",
},
}
a.errorResponse(w, r, httpError)
}

func (a *App) forbiddenResponse(w http.ResponseWriter, r *http.Request, msg string) {
a.LogWarn(r, msg)

httpError := &HTTPError{
StatusCode: http.StatusForbidden,
ErrorResponse: ErrorResponse{
Code: strconv.Itoa(http.StatusForbidden),
Message: "you are not authorized to access this resource",
},
}
a.errorResponse(w, r, httpError)
}

//nolint:unused
func (a *App) failedValidationResponse(w http.ResponseWriter, r *http.Request, errors map[string]string) {
message, err := json.Marshal(errors)
Expand Down
15 changes: 0 additions & 15 deletions workspaces/backend/api/healthcheck_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,13 @@ import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/kubeflow/notebooks/workspaces/backend/internal/config"
models "github.com/kubeflow/notebooks/workspaces/backend/internal/models/health_check"
"github.com/kubeflow/notebooks/workspaces/backend/internal/repositories"
)

var _ = Describe("HealthCheck Handler", func() {
var (
a App
)

Context("when backend is healthy", func() {

BeforeEach(func() {
repos := repositories.NewRepositories(k8sClient)
a = App{
Config: config.EnvConfig{
Port: 4000,
},
repositories: repos,
}
})

It("should return a health check response", func() {
By("creating the HTTP request")
req, err := http.NewRequest(http.MethodGet, HealthCheckPath, http.NoBody)
Expand Down
14 changes: 14 additions & 0 deletions workspaces/backend/api/namespaces_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,28 @@ import (
"net/http"

"github.com/julienschmidt/httprouter"
corev1 "k8s.io/api/core/v1"

"github.com/kubeflow/notebooks/workspaces/backend/internal/auth"
models "github.com/kubeflow/notebooks/workspaces/backend/internal/models/namespaces"
)

type NamespacesEnvelope Envelope[[]models.Namespace]

func (a *App) GetNamespacesHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {

// =========================== AUTH ===========================
authPolicies := []*auth.ResourcePolicy{
auth.NewResourcePolicy(
auth.ResourceVerbList,
&corev1.Namespace{},
),
}
if success := a.requireAuth(w, r, authPolicies); !success {
return
}
// ============================================================

namespaces, err := a.repositories.Namespace.GetNamespaces(r.Context())
if err != nil {
a.serverErrorResponse(w, r, err)
Expand Down
16 changes: 3 additions & 13 deletions workspaces/backend/api/namespaces_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,10 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

"github.com/kubeflow/notebooks/workspaces/backend/internal/config"
models "github.com/kubeflow/notebooks/workspaces/backend/internal/models/namespaces"
"github.com/kubeflow/notebooks/workspaces/backend/internal/repositories"
)

var _ = Describe("Namespaces Handler", func() {
var (
a App
)

// NOTE: these tests assume a specific state of the cluster, so cannot be run in parallel with other tests.
// therefore, we run them using the `Serial` Ginkgo decorators.
Expand All @@ -47,14 +42,6 @@ var _ = Describe("Namespaces Handler", func() {
const namespaceName2 = "get-ns-test-ns2"

BeforeEach(func() {
repos := repositories.NewRepositories(k8sClient)
a = App{
Config: config.EnvConfig{
Port: 4000,
},
repositories: repos,
}

By("creating Namespace 1")
namespace1 := &corev1.Namespace{
ObjectMeta: metav1.ObjectMeta{
Expand Down Expand Up @@ -95,6 +82,9 @@ var _ = Describe("Namespaces Handler", func() {
req, err := http.NewRequest(http.MethodGet, AllNamespacesPath, http.NoBody)
Expect(err).NotTo(HaveOccurred())

By("setting the auth headers")
req.Header.Set(userIdHeader, adminUser)

By("executing GetNamespacesHandler")
ps := httprouter.Params{}
rr := httptest.NewRecorder()
Expand Down
68 changes: 60 additions & 8 deletions workspaces/backend/api/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,52 @@ package api
import (
"context"
"fmt"
"log/slog"
"path/filepath"
"runtime"
"testing"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

kubefloworgv1beta1 "github.com/kubeflow/notebooks/workspaces/controller/api/v1beta1"
v1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/client-go/rest"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/envtest"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"

kubefloworgv1beta1 "github.com/kubeflow/notebooks/workspaces/controller/api/v1beta1"
"github.com/kubeflow/notebooks/workspaces/backend/internal/auth"
"github.com/kubeflow/notebooks/workspaces/backend/internal/config"
)

// These tests use Ginkgo (BDD-style Go testing framework). Refer to
// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.

const (
userIdHeader = "userid-header"
userIdPrefix = ""
groupsHeader = "groups-header"

adminUser = "notebooks-admin"
)

var (
testEnv *envtest.Environment
cfg *rest.Config

k8sClient client.Client

a *App

ctx context.Context
cancel context.CancelFunc
)
Expand Down Expand Up @@ -95,6 +108,30 @@ var _ = BeforeSuite(func() {
Expect(err).NotTo(HaveOccurred())
Expect(k8sClient).NotTo(BeNil())

By("creating the notebooks-admin ClusterRoleBinding")
Expect(k8sClient.Create(ctx, &rbacv1.ClusterRoleBinding{
ObjectMeta: metav1.ObjectMeta{
Name: "notebooks-admin",
},
Subjects: []rbacv1.Subject{
{
Kind: "User",
Name: adminUser,
},
},
RoleRef: rbacv1.RoleRef{
Kind: "ClusterRole",
Name: "cluster-admin",
},
})).To(Succeed())

By("listing the clusterRoles")
clusterRoles := &rbacv1.ClusterRoleList{}
Expect(k8sClient.List(ctx, clusterRoles)).To(Succeed())
for _, clusterRole := range clusterRoles.Items {
fmt.Printf("ClusterRole: %s\n", clusterRole.Name)
}

By("setting up the controller manager")
k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{
Scheme: scheme.Scheme,
Expand All @@ -104,6 +141,21 @@ var _ = BeforeSuite(func() {
})
Expect(err).NotTo(HaveOccurred())

By("initializing the application logger")
appLogger := slog.New(slog.NewTextHandler(GinkgoWriter, nil))

By("creating the request authenticator")
reqAuthN, err := auth.NewRequestAuthenticator(userIdHeader, userIdPrefix, groupsHeader)
Expect(err).NotTo(HaveOccurred())

By("creating the request authorizer")
reqAuthZ, err := auth.NewRequestAuthorizer(k8sManager.GetConfig(), k8sManager.GetHTTPClient())
Expect(err).NotTo(HaveOccurred())

By("creating the application")
// NOTE: we use the `k8sClient` rather than `k8sManager.GetClient()` to avoid race conditions with the cached client
a, err = NewApp(&config.EnvConfig{}, appLogger, k8sClient, k8sManager.GetScheme(), reqAuthN, reqAuthZ)

go func() {
defer GinkgoRecover()
err = k8sManager.Start(ctx)
Expand Down
Loading

0 comments on commit bc6f311

Please sign in to comment.