From 9b276fd77d13467bdc6769271c84ed4eb2825f6f Mon Sep 17 00:00:00 2001 From: Kevin Conner Date: Wed, 18 Sep 2024 15:22:49 -0700 Subject: [PATCH] Support for JWT Authorization and token refresh Signed-off-by: Kevin Conner --- .github/workflows/docker.yaml | 2 + Makefile | 19 +- charts/zora/README.md | 20 + charts/zora/templates/_helpers.tpl | 23 ++ charts/zora/templates/hooks/install.yaml | 1 + .../zora/templates/operator/deployment.yaml | 14 +- .../templates/tokenrefresh/deployment.yaml | 93 +++++ charts/zora/templates/tokenrefresh/rbac.yaml | 53 +++ charts/zora/values.yaml | 51 +++ cmd/main.go | 8 +- cmd/tokenrefresh/Dockerfile | 34 ++ cmd/tokenrefresh/main.go | 352 ++++++++++++++++++ go.mod | 3 +- go.sum | 1 + internal/saas/client.go | 44 ++- pkg/authentication/tokendata.go | 73 ++++ pkg/filemonitor/filemonitor.go | 293 +++++++++++++++ pkg/filemonitor/filemonitor_linux_test.go | 113 ++++++ pkg/filemonitor/filemonitor_test.go | 226 +++++++++++ 19 files changed, 1406 insertions(+), 17 deletions(-) create mode 100644 charts/zora/templates/tokenrefresh/deployment.yaml create mode 100644 charts/zora/templates/tokenrefresh/rbac.yaml create mode 100644 cmd/tokenrefresh/Dockerfile create mode 100644 cmd/tokenrefresh/main.go create mode 100644 pkg/authentication/tokendata.go create mode 100644 pkg/filemonitor/filemonitor.go create mode 100644 pkg/filemonitor/filemonitor_linux_test.go create mode 100644 pkg/filemonitor/filemonitor_test.go diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index bed8362c..379af6f3 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -22,6 +22,8 @@ jobs: image: operator - dockerfile: cmd/worker/Dockerfile image: worker + - dockerfile: cmd/refreshtoken/Dockerfile + image: refreshtoken steps: - name: checkout uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index ad1983bc..ffd02c99 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ # Image URL to use all building/pushing image targets IMG ?= operator:latest WORKER_IMG ?= worker:latest +TOKENREFRESH_IMG ?= tokenrefresh:latest # ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary. ENVTEST_K8S_VERSION = 1.29.3 @@ -97,9 +98,10 @@ lint-fix: golangci-lint ## Run golangci-lint linter and perform fixes ##@ Build .PHONY: build -build: manifests generate fmt vet ## Build manager and worker binaries. +build: manifests generate fmt vet ## Build manager, worker and tokenrefresh binaries. go build -o bin/manager cmd/main.go go build -o bin/worker cmd/worker/main.go + go build -o bin/tokenrefresh cmd/tokenrefresh/main.go .PHONY: run run: manifests generate fmt vet ## Run a controller from your host. @@ -116,6 +118,10 @@ docker-build: test ## Build docker image with the manager. docker-build-worker: test ## Build docker image with worker. $(CONTAINER_TOOL) build -t ${WORKER_IMG} -f cmd/worker/Dockerfile . +.PHONY: docker-build-tokenrefresh +docker-build-tokenrefresh: test ## Build docker image with tokenrefresh. + $(CONTAINER_TOOL) build -t ${TOKENREFRESH_IMG} -f cmd/tokenrefresh/Dockerfile . + .PHONY: docker-push docker-push: ## Push docker image with the manager. $(CONTAINER_TOOL) push ${IMG} @@ -124,6 +130,10 @@ docker-push: ## Push docker image with the manager. docker-push-worker: ## Push docker image with worker. $(CONTAINER_TOOL) push ${WORKER_IMG} +.PHONY: docker-push-tokenrefresh +docker-push-tokenrefresh: ## Push docker image with tokenrefresh. + $(CONTAINER_TOOL) push ${TOKENREFRESH_IMG} + # PLATFORMS defines the target platforms for the manager image be build to provide support to multiple # architectures. (i.e. make docker-buildx IMG=myregistry/myoperator:0.0.1). To use this option you need to: # - able to use docker buildx . More info: https://docs.docker.com/build/buildx/ @@ -191,9 +201,10 @@ kind-create-cluster: kind ## Create a local Kubernetes cluster with Kind $(KIND) create cluster --name $(CLUSTER_NAME) .PHONY: kind-load-images -kind-load-images: kind docker-build docker-build-worker ## Build and load docker images into Kind nodes - $(KIND) load docker-image ${IMG} - $(KIND) load docker-image ${WORKER_IMG} +kind-load-images: kind docker-build docker-build-worker docker-build-tokenrefresh ## Build and load docker images into Kind nodes + $(KIND) load docker-image -n $(CLUSTER_NAME) ${IMG} + $(KIND) load docker-image -n $(CLUSTER_NAME) ${WORKER_IMG} + $(KIND) load docker-image -n $(CLUSTER_NAME) ${TOKENREFRESH_IMG} ##@ Build Dependencies diff --git a/charts/zora/README.md b/charts/zora/README.md index 56fc902b..d05dba18 100644 --- a/charts/zora/README.md +++ b/charts/zora/README.md @@ -143,6 +143,26 @@ The following table lists the configurable parameters of the Zora chart and thei | httpsProxy | string | `""` | HTTPS proxy URL | | noProxy | string | `"kubernetes.default.svc.*,127.0.0.1,localhost"` | Comma-separated list of URL patterns to be excluded from going through the proxy | | updateCRDs | bool | `true` for upgrades | Specifies whether CRDs should be updated by operator at startup | +| tokenRefresh.image.repository | string | `"ghcr.io/undistro/zora/tokenrefresh"` | tokenrefresh image repository | +| tokenRefresh.image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion | +| tokenRefresh.image.pullPolicy | string | `"IfNotPresent"` | Image pull policy | +| tokenRefresh.rbac.create | bool | `true` | Specifies whether Roles and RoleBindings should be created | +| tokenRefresh.rbac.serviceAccount.create | bool | `true` | Specifies whether a service account should be created | +| tokenRefresh.rbac.serviceAccount.annotations | object | `{}` | Annotations to be added to service account | +| tokenRefresh.rbac.serviceAccount.name | string | `""` | The name of the service account to use. If not set and create is true, a name is generated using the fullname template | +| tokenRefresh.minRefreshTime | string | `"1m"` | Minimum time to wait before checking for token refresh | +| tokenRefresh.refreshThreshold | string | `"2h"` | Threshold relative to the token expiry timestamp, after which a token can be refreshed. | +| tokenRefresh.nodeSelector | object | `{}` | [Node selection](https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node) to constrain a Pod to only be able to run on particular Node(s) | +| tokenRefresh.tolerations | list | `[]` | [Tolerations](https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration) for pod assignment | +| tokenRefresh.affinity | object | `{}` | Map of node/pod [affinities](https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration) | +| tokenRefresh.podAnnotations | object | `{"kubectl.kubernetes.io/default-container":"manager"}` | Annotations to be added to pods | +| tokenRefresh.podSecurityContext | object | `{"runAsNonRoot":true}` | [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context) to add to the pod | +| tokenRefresh.securityContext | object | `{"allowPrivilegeEscalation":false,"readOnlyRootFilesystem":true}` | [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context) to add to `manager` container | +| zoraauth.domain | string | `""` | The domain associated with the tokens | +| zoraauth.clientId | string | `""` | The client id associated with the tokens | +| zoraauth.accessToken | string | `""` | The access token authorizing access to the SaaS API server | +| zoraauth.tokenType | string | `"Bearer"` | The type of the access token | +| zoraauth.refreshToken | string | `""` | The refresh token for obtaining a new access token | Specify each parameter using the `--set key=value[,key=value]` argument to `helm install`. For example, diff --git a/charts/zora/templates/_helpers.tpl b/charts/zora/templates/_helpers.tpl index 0c23f8a7..7a3ebde7 100644 --- a/charts/zora/templates/_helpers.tpl +++ b/charts/zora/templates/_helpers.tpl @@ -77,6 +77,25 @@ Create the name of the service account to use in Operator {{- end }} {{- end }} +{{/* +TokenRefresh selector labels +*/}} +{{- define "zora.tokenRefreshSelectorLabels" -}} +{{ include "zora.selectorLabels" . }} +app.kubernetes.io/component: token-refresh +{{- end }} + +{{/* +Create the name of the service account to use in TokenRefresh +*/}} +{{- define "zora.tokenRefreshServiceAccountName" -}} +{{- if .Values.tokenRefresh.rbac.serviceAccount.create }} +{{- default (printf "%s-%s" (include "zora.fullname" .) "token-refresh") .Values.tokenRefresh.rbac.serviceAccount.name }} +{{- else }} +{{- default "default" .Values.tokenRefresh.rbac.serviceAccount.name }} +{{- end }} +{{- end }} + {{- define "zora.imagePullSecret" }} {{- with .Values.imageCredentials }} {{- printf "{\"auths\":{\"%s\":{\"auth\":\"%s\"}}}" .registry (printf "%s:%s" .username .password | b64enc) | b64enc }} @@ -149,3 +168,7 @@ Truncate a name to a specific length {{- $isHourBad := not (mustRegexMatch "^(?:\\d|[0-5]\\d)$" $hour) -}} {{- or $isMinuteBad $isHourBad -}} {{- end -}} + +{{- define "zora.saasTokenSecretName" -}} +{{- printf "%s-saas-tokens" (include "zora.fullname" .) -}} +{{- end }} diff --git a/charts/zora/templates/hooks/install.yaml b/charts/zora/templates/hooks/install.yaml index a75315f9..3b757cf5 100644 --- a/charts/zora/templates/hooks/install.yaml +++ b/charts/zora/templates/hooks/install.yaml @@ -37,6 +37,7 @@ spec: - | curl -kfsS -X POST '{{ tpl .Values.saas.installURL . }}' \ -H 'content-type: application/json' \ + -H 'Authorization: {{ .Values.zoraauth.tokenType }} {{ .Values.zoraauth.accessToken }}' \ {{- if .Values.httpsProxy }} -x '{{ .Values.httpsProxy}}' \ {{- end }} diff --git a/charts/zora/templates/operator/deployment.yaml b/charts/zora/templates/operator/deployment.yaml index 576aea5c..46c0ea61 100644 --- a/charts/zora/templates/operator/deployment.yaml +++ b/charts/zora/templates/operator/deployment.yaml @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. {{ $secretName := printf "%s-serving-cert" (include "zora.fullname" .) -}} +{{ $saasTokensSecretName := (include "zora.saasTokenSecretName" .) -}} {{- $serviceName := printf "%s-webhook" (include "zora.fullname" .) -}} {{- if .Values.operator.webhook.enabled -}} {{- $existingSecret := lookup "v1" "Secret" .Release.Namespace $secretName -}} @@ -118,6 +119,7 @@ spec: - --inject-conversion={{ .Values.operator.webhook.enabled }} - --webhook-service-name={{ $serviceName }} - --webhook-service-namespace={{ .Release.Namespace }} + - --token-path=/tmp/jwt-tokens/token image: "{{ .Values.operator.image.repository }}:{{ .Values.operator.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.operator.image.pullPolicy }} ports: @@ -131,11 +133,16 @@ spec: - containerPort: 9443 name: webhook-server protocol: TCP + {{- end }} volumeMounts: + {{- if .Values.operator.webhook.enabled }} - mountPath: /tmp/k8s-webhook-server/serving-certs name: cert readOnly: true {{- end }} + - mountPath: /tmp/jwt-tokens + name: jwt-tokens + readOnly: true livenessProbe: httpGet: path: /healthz @@ -152,14 +159,19 @@ spec: {{- toYaml .Values.operator.resources | nindent 12 }} securityContext: {{- toYaml .Values.operator.securityContext | nindent 12 }} - {{- if .Values.operator.webhook.enabled }} volumes: + {{- if .Values.operator.webhook.enabled }} - name: cert secret: defaultMode: 420 secretName: {{ $secretName }} optional: true {{- end }} + - name: jwt-tokens + secret: + defaultMode: 420 + secretName: {{ $saasTokensSecretName }} + optional: true securityContext: {{- toYaml .Values.operator.podSecurityContext | nindent 8 }} serviceAccountName: {{ include "zora.operatorServiceAccountName" . }} diff --git a/charts/zora/templates/tokenrefresh/deployment.yaml b/charts/zora/templates/tokenrefresh/deployment.yaml new file mode 100644 index 00000000..b82781e2 --- /dev/null +++ b/charts/zora/templates/tokenrefresh/deployment.yaml @@ -0,0 +1,93 @@ +# Copyright 2022 Undistro Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +{{- if .Values.saas.workspaceID -}} +{{ $secretName := (include "zora.saasTokenSecretName" .) -}} +{{- $existingSecret := lookup "v1" "Secret" .Release.Namespace $secretName -}} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} +type: undistro.io/jwtTokens +data: +{{- if $existingSecret }} + {{- toYaml $existingSecret.data | nindent 2 }} +{{- else }} + token: {{ printf "{ \"access_token\": \"%s\", \"refresh_token\": \"%s\", \"token_type\": \"%s\" }" .Values.zoraauth.accessToken .Values.zoraauth.refreshToken .Values.zoraauth.tokenType | b64enc }} +{{- end }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "zora.fullname" . }}-tokenrefresh + labels: + {{- include "zora.tokenRefreshSelectorLabels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "zora.tokenRefreshSelectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.tokenRefresh.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "zora.tokenRefreshSelectorLabels" . | nindent 8 }} + spec: + containers: + - name: tokenrefresh + {{- if .Values.httpsProxy }} + env: + - name: HTTPS_PROXY + value: {{ .Values.httpsProxy | quote }} + - name: NO_PROXY + value: {{ .Values.noProxy | quote }} + {{- end }} + command: + - /tokenrefresh + args: + - --secret-name={{ $secretName }} + - --namespace={{ .Release.Namespace }} + - --domain={{ .Values.zoraauth.domain }} + - --client-id={{ .Values.zoraauth.clientId }} + - --min-refresh-time={{ .Values.tokenRefresh.minRefreshTime }} + - --refresh-threshold={{ .Values.tokenRefresh.refreshThreshold }} + image: "{{ .Values.tokenRefresh.image.repository }}:{{ .Values.tokenRefresh.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.tokenRefresh.image.pullPolicy }} + resources: + {{- toYaml .Values.tokenRefresh.resources | nindent 12 }} + securityContext: + {{- toYaml .Values.tokenRefresh.securityContext | nindent 12 }} + volumes: + - name: jwt-tokens + secret: + defaultMode: 420 + secretName: {{ $secretName }} + securityContext: + {{- toYaml .Values.tokenRefresh.podSecurityContext | nindent 8 }} + serviceAccountName: {{ include "zora.tokenRefreshServiceAccountName" . }} + terminationGracePeriodSeconds: 10 + {{- with .Values.tokenRefresh.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tokenRefresh.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tokenRefresh.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end -}} diff --git a/charts/zora/templates/tokenrefresh/rbac.yaml b/charts/zora/templates/tokenrefresh/rbac.yaml new file mode 100644 index 00000000..a17482ce --- /dev/null +++ b/charts/zora/templates/tokenrefresh/rbac.yaml @@ -0,0 +1,53 @@ +# Copyright 2023 Undistro Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{{- if .Values.saas.workspaceID -}} +{{ if .Values.tokenRefresh.rbac.serviceAccount.create -}} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ include "zora.tokenRefreshServiceAccountName" . }} + labels: + {{- include "zora.tokenRefreshSelectorLabels" . | nindent 4 }} + {{- with .Values.tokenRefresh.rbac.serviceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +{{ end }} +{{- if .Values.tokenRefresh.rbac.create -}} +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ include "zora.tokenRefreshServiceAccountName" . }}-secret-access +rules: +- apiGroups: [""] + resources: ["secrets"] + verbs: ["get", "list", "watch", "update", "patch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ include "zora.tokenRefreshServiceAccountName" . }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ include "zora.tokenRefreshServiceAccountName" . }}-secret-access +subjects: +- kind: ServiceAccount + name: {{ include "zora.tokenRefreshServiceAccountName" . }} + namespace: {{ .Release.Namespace }} +{{- end }} +{{- end -}} +--- diff --git a/charts/zora/values.yaml b/charts/zora/values.yaml index d0e937e8..dec1926d 100644 --- a/charts/zora/values.yaml +++ b/charts/zora/values.yaml @@ -305,3 +305,54 @@ noProxy: kubernetes.default.svc.*,127.0.0.1,localhost # -- (bool) Specifies whether CRDs should be updated by operator at startup # @default -- `true` for upgrades updateCRDs: + +tokenRefresh: + image: + # -- tokenrefresh image repository + repository: ghcr.io/undistro/zora/tokenrefresh + # -- Overrides the image tag whose default is the chart appVersion + tag: "" + # -- Image pull policy + pullPolicy: IfNotPresent + rbac: + # -- Specifies whether Roles and RoleBindings should be created + create: true + serviceAccount: + # -- Specifies whether a service account should be created + create: true + # -- Annotations to be added to service account + annotations: {} + # -- The name of the service account to use. If not set and create is true, a name is generated using the fullname template + name: "" + # -- Minimum time to wait before checking for token refresh + minRefreshTime: "1m" + # -- Threshold relative to the token expiry timestamp, after which a token can be refreshed. + refreshThreshold: "2h" + # -- [Node selection](https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node) to constrain a Pod to only be able to run on particular Node(s) + nodeSelector: {} + # -- [Tolerations](https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration) for pod assignment + tolerations: [] + # -- Map of node/pod [affinities](https://kubernetes.io/docs/concepts/scheduling-eviction/taint-and-toleration) + affinity: {} + # -- Annotations to be added to pods + podAnnotations: + kubectl.kubernetes.io/default-container: manager + # -- [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context) to add to the pod + podSecurityContext: + runAsNonRoot: true + # -- [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context) to add to `manager` container + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + +zoraauth: + # -- The domain associated with the tokens + domain: "" + # -- The client id associated with the tokens + clientId: "" + # -- The access token authorizing access to the SaaS API server + accessToken: "" + # -- The type of the access token + tokenType: "Bearer" + # -- The refresh token for obtaining a new access token + refreshToken: "" diff --git a/cmd/main.go b/cmd/main.go index 65891f9d..72488a89 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -88,6 +88,7 @@ func main() { var webhookServiceName string var webhookServiceNamespace string var webhookServicePath string + var tokenPath string flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") @@ -124,6 +125,11 @@ func main() { "Webhook service namespace") flag.StringVar(&webhookServicePath, "webhook-service-path", "/convert", "URL path for webhook conversion") + flag.StringVar(&tokenPath, "token-path", "/tmp/jwt-tokens/tokens", + "URL path for authorization tokens") + + done := make(chan struct{}) + defer close(done) opts := zap.Options{ Development: true, @@ -175,7 +181,7 @@ func main() { var onClusterScanUpdate, onClusterScanDelete saas.ClusterScanHook client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyFromEnvironment}} if saasWorkspaceID != "" { - saasClient, err := saas.NewClient(saasServer, version, saasWorkspaceID, client) + saasClient, err := saas.NewClient(saasServer, version, saasWorkspaceID, client, tokenPath, done) if err != nil { setupLog.Error(err, "unable to create SaaS client", "workspaceID", saasWorkspaceID) os.Exit(1) diff --git a/cmd/tokenrefresh/Dockerfile b/cmd/tokenrefresh/Dockerfile new file mode 100644 index 00000000..48f84d35 --- /dev/null +++ b/cmd/tokenrefresh/Dockerfile @@ -0,0 +1,34 @@ +# Copyright 2022 Undistro Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM golang:1.22 AS builder +ARG TARGETOS +ARG TARGETARCH + +WORKDIR /workspace +COPY go.mod go.mod +COPY go.sum go.sum +RUN go mod download + +COPY cmd/tokenrefresh/main.go cmd/tokenrefresh/main.go +COPY pkg/ pkg/ + +RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o tokenrefresh cmd/tokenrefresh/main.go + +FROM gcr.io/distroless/static:nonroot +WORKDIR / +COPY --from=builder /workspace/tokenrefresh . +USER 65532:65532 + +ENTRYPOINT ["/tokenrefresh"] diff --git a/cmd/tokenrefresh/main.go b/cmd/tokenrefresh/main.go new file mode 100644 index 00000000..9560da6f --- /dev/null +++ b/cmd/tokenrefresh/main.go @@ -0,0 +1,352 @@ +// Copyright 2024 Undistro Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/signal" + "reflect" + "sync" + "syscall" + "time" + + "github.com/undistro/zora/pkg/authentication" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/tools/clientcmd" +) + +var ( + secretName string + secretNamespace string + tokenName string + domain string + clientID string + minRefreshTime time.Duration + refreshThreshold time.Duration + kubeconfig string +) + +const ( + annotationStatus = "zora.undistro.io/status" +) + +type SecretStatus struct { + LastRefreshTime time.Time `json:"lastRefreshTime"` + NextScheduledRefresh time.Time `json:"nextScheduledRefresh"` + TokenExpiry time.Time `json:"tokenExpiry"` +} + +type Controller struct { + clientset *kubernetes.Clientset + refreshCh chan struct{} + mutex sync.Mutex +} + +func init() { + flag.StringVar(&secretName, "secret-name", "oauth-tokens", "Name of the secret containing OAuth tokens") + flag.StringVar(&secretNamespace, "namespace", "default", "Namespace of the secret") + flag.StringVar(&tokenName, "token-name", "token", "Name of token within the secret's data") + flag.StringVar(&domain, "domain", "", "Domain name associated with the token") + flag.StringVar(&clientID, "client-id", "", "Client ID associated with the token") + flag.DurationVar(&minRefreshTime, "min-refresh-time", 5*time.Minute, "Minimum time between refresh attempts") + flag.DurationVar(&refreshThreshold, "refresh-threshold", 5*time.Minute, "Time before token expiry to attempt refresh") + flag.StringVar(&kubeconfig, "kubeconfig", "", "Path to kubeconfig file. If not set, in-cluster config will be used") +} + +func main() { + flag.Parse() + + // Set up Kubernetes client + var config *rest.Config + var err error + if kubeconfig == "" { + config, err = rest.InClusterConfig() + } else { + config, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + } + if err != nil { + panic(err) + } + clientset, err := kubernetes.NewForConfig(config) + if err != nil { + panic(err) + } + + controller := &Controller{ + clientset: clientset, + refreshCh: make(chan struct{}, 1), + } + + // Create a watch on the secret + watchlist := cache.NewListWatchFromClient( + clientset.CoreV1().RESTClient(), + "secrets", + secretNamespace, + fields.OneTermEqualSelector("metadata.name", secretName), + ) + + _, informer := cache.NewInformer( + watchlist, + &v1.Secret{}, + 0, + cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { controller.handleSecretChange("added", obj) }, + UpdateFunc: func(_, newObj interface{}) { controller.handleSecretChange("updated", newObj) }, + }, + ) + + // Start the informer + stop := make(chan struct{}) + defer close(stop) + go informer.Run(stop) + + // Start the refresh loop + controller.refreshLoop(context.Background()) +} + +func (c *Controller) handleSecretChange(operation string, obj interface{}) { + secret, ok := obj.(*v1.Secret) + if !ok { + slog.Error(fmt.Sprintf("Error: Unexpected object type %s", reflect.TypeOf(secret))) + return + } + + slog.Info(fmt.Sprintf("Secret Change: %s %s", secret.Name, operation)) + + // Trigger an asynchronous refresh + select { + case c.refreshCh <- struct{}{}: + slog.Info("Secret Change:Triggered asynchronous refresh") + default: + slog.Info("Secret Change:Refresh already queued") + } +} + +func (c *Controller) refreshLoop(ctx context.Context) { + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) + + for { + c.mutex.Lock() + _, expiryTime, _, err := c.getTokenData(ctx) + c.mutex.Unlock() + if err != nil { + slog.Error(fmt.Sprintf("Error getting token data: %v", err)) + time.Sleep(minRefreshTime) + continue + } + + timeUntilExpiry := time.Until(expiryTime) + refreshTime := timeUntilExpiry - refreshThreshold + + if refreshTime < minRefreshTime { + refreshTime = minRefreshTime + } + + slog.Info(fmt.Sprintf("Refresh Loop: next refresh scheduled in %v", refreshTime)) + + select { + case <-time.After(refreshTime): + slog.Info("Refresh Loop: refreshing token due to scheduled refresh") + case <-c.refreshCh: + slog.Info("Refresh Loop: refreshing token due to secret change") + case sig := <-signalCh: + slog.Info(fmt.Sprintf("Refresh Loop: received signal %s, terminating application", sig)) + return + } + + c.refreshTokenIfNeeded(ctx) + } +} + +func (c *Controller) getTokenData(ctx context.Context) (*authentication.TokenData, time.Time, *SecretStatus, error) { + secret, err := c.clientset.CoreV1().Secrets(secretNamespace).Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + return nil, time.Time{}, nil, err + } + + tokenData, err := authentication.ParseTokenData(secret.Data[tokenName]) + if err != nil { + return nil, time.Time{}, nil, err + } + + expiryTime, err := authentication.GetJWTExpiry(tokenData) + if err != nil { + return nil, time.Time{}, nil, err + } + + secretStatus := getSecretStatus(secret) + return tokenData, expiryTime, secretStatus, nil +} + +func (c *Controller) refreshTokenIfNeeded(ctx context.Context) { + c.mutex.Lock() + defer c.mutex.Unlock() + + tokenData, expiryTime, secretStatus, err := c.getTokenData(ctx) + if err != nil { + slog.Error(fmt.Sprintf("Error getting token data: %v", err)) + return + } + + if time.Until(expiryTime) > refreshThreshold { + slog.Info("Refresh Loop: token is still valid, no need to refresh") + if secretStatus == nil || !secretStatus.TokenExpiry.Equal(expiryTime) { + // Update secret status + err = c.updateSecretStatus(ctx, time.Now(), expiryTime) + if err != nil { + slog.Error(fmt.Sprintf("Error updating secret status: %v", err)) + } + } + } else { + tokenData, err = refreshToken(domain, clientID, tokenData.RefreshToken) + if err != nil { + slog.Error(fmt.Sprintf("Error refreshing token: %v", err)) + return + } + + err = c.updateSecret(ctx, tokenData) + if err != nil { + slog.Error(fmt.Sprintf("Error updating secret: %v", err)) + return + } + + slog.Info("Refresh Loop: token refreshed successfully") + } +} + +func refreshToken(domain, clientID, refreshToken string) (*authentication.TokenData, error) { + url := fmt.Sprintf("https://%s/oauth/token", domain) + data := fmt.Sprintf("grant_type=refresh_token&client_id=%s&refresh_token=%s", clientID, refreshToken) + + resp, err := http.Post(url, "application/x-www-form-urlencoded", bytes.NewBufferString(data)) + if err != nil { + return nil, fmt.Errorf("failed to request device code: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response status: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var tokenData authentication.TokenData + if err := json.Unmarshal(body, &tokenData); err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return &tokenData, nil +} + +func (c *Controller) updateSecret(ctx context.Context, tokenData *authentication.TokenData) error { + secret, err := c.clientset.CoreV1().Secrets(secretNamespace).Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + return err + } + + expiryTime, _ := authentication.GetJWTExpiry(tokenData) + status := newSecretStatus(expiryTime, time.Now()) + err = setSecretStatus(secret, status) + if err != nil { + return err + } + + tokenBytes, err := json.Marshal(tokenData) + if err != nil { + return err + } + + secret.Data[tokenName] = tokenBytes + _, err = c.clientset.CoreV1().Secrets(secretNamespace).Update(ctx, secret, metav1.UpdateOptions{}) + return err +} + +func (c *Controller) updateSecretStatus(ctx context.Context, lastRefreshTime, tokenExpiry time.Time) error { + status := newSecretStatus(tokenExpiry, lastRefreshTime) + secret := v1.Secret{} + err := setSecretStatus(&secret, status) + if err != nil { + return err + } + + patch, err := json.Marshal(secret) + if err != nil { + return err + } + + _, err = c.clientset.CoreV1().Secrets(secretNamespace).Patch(ctx, secretName, types.StrategicMergePatchType, patch, metav1.PatchOptions{}) + return err +} + +func newSecretStatus(tokenExpiry time.Time, lastRefreshTime time.Time) *SecretStatus { + nextScheduledRefresh := tokenExpiry.Add(-refreshThreshold) + + status := SecretStatus{ + LastRefreshTime: lastRefreshTime, + NextScheduledRefresh: nextScheduledRefresh, + TokenExpiry: tokenExpiry, + } + return &status +} + +func getSecretStatus(secret *v1.Secret) *SecretStatus { + if secret != nil && secret.Annotations != nil { + statusVal := secret.Annotations[annotationStatus] + secretStatus := SecretStatus{} + err := json.Unmarshal([]byte(statusVal), &secretStatus) + + if err == nil { + return &secretStatus + } + } + return nil +} + +func setSecretStatus(secret *v1.Secret, status *SecretStatus) error { + if secret == nil { + return errors.New("missing secret") + } + if status == nil { + return errors.New("missing status") + } + if secret.Annotations == nil { + secret.Annotations = map[string]string{} + } + statusVal, err := json.Marshal(status) + if err != nil { + return err + } + secret.Annotations[annotationStatus] = string(statusVal) + return nil +} diff --git a/go.mod b/go.mod index cc7d0901..e1a5c0a9 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/evanphx/json-patch v5.7.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect github.com/fatih/color v1.17.0 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 github.com/go-errors/errors v1.4.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect @@ -47,6 +47,7 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect diff --git a/go.sum b/go.sum index e4678317..4690e8f0 100644 --- a/go.sum +++ b/go.sum @@ -247,6 +247,7 @@ github.com/goccy/go-yaml v1.9.5/go.mod h1:U/jl18uSupI5rdI2jmuCswEA2htH9eXfferR3K github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= diff --git a/internal/saas/client.go b/internal/saas/client.go index 8c1a24db..7934a8ca 100644 --- a/internal/saas/client.go +++ b/internal/saas/client.go @@ -24,6 +24,8 @@ import ( "path" "github.com/undistro/zora/api/zora/v1alpha2" + "github.com/undistro/zora/pkg/authentication" + "github.com/undistro/zora/pkg/filemonitor" ) const ( @@ -48,22 +50,29 @@ type Client interface { } type client struct { - client *http.Client - baseURL *url.URL - workspaceID string - version string + client *http.Client + baseURL *url.URL + workspaceID string + version string + tokenMonitor *filemonitor.FileMonitor } -func NewClient(baseURL, version, workspaceID string, httpclient *http.Client) (Client, error) { +func NewClient(baseURL, version, workspaceID string, httpclient *http.Client, tokenPath string, done <-chan struct{}) (Client, error) { u, err := validateURL(baseURL) if err != nil { return nil, err } + tokenMonitor := filemonitor.NewFileMonitor(tokenPath, func(content []byte) (any, error) { + return authentication.ParseTokenData(content) + }) + go tokenMonitor.MonitorFile(done) + return &client{ - version: version, - baseURL: u, - workspaceID: workspaceID, - client: httpclient, + version: version, + baseURL: u, + workspaceID: workspaceID, + client: httpclient, + tokenMonitor: tokenMonitor, }, nil } @@ -77,6 +86,7 @@ func (r *client) PutCluster(ctx context.Context, cluster Cluster) error { if err != nil { return err } + r.addAuthorizationHeader(req) req.Header.Set("content-type", "application/json") req.Header.Set(versionHeader, r.version) res, err := r.client.Do(req) @@ -93,6 +103,7 @@ func (r *client) DeleteCluster(ctx context.Context, namespace, name string) erro if err != nil { return err } + r.addAuthorizationHeader(req) req.Header.Set(versionHeader, r.version) res, err := r.client.Do(req) if err != nil { @@ -112,6 +123,7 @@ func (r *client) PutClusterScan(ctx context.Context, namespace, name string, plu if err != nil { return err } + r.addAuthorizationHeader(req) req.Header.Set("content-type", "application/json") req.Header.Set(versionHeader, r.version) res, err := r.client.Do(req) @@ -132,6 +144,7 @@ func (r *client) PutVulnerabilityReport(ctx context.Context, namespace, name str if err != nil { return err } + r.addAuthorizationHeader(req) req.Header.Set("content-type", "application/json") req.Header.Set(versionHeader, r.version) res, err := r.client.Do(req) @@ -148,8 +161,9 @@ func (r *client) DeleteClusterScan(ctx context.Context, namespace, name string) if err != nil { return err } - res, err := r.client.Do(req) + r.addAuthorizationHeader(req) req.Header.Set(versionHeader, r.version) + res, err := r.client.Do(req) if err != nil { return err } @@ -167,6 +181,7 @@ func (r *client) PutClusterStatus(ctx context.Context, namespace, name string, p if err != nil { return err } + r.addAuthorizationHeader(req) req.Header.Set("content-type", "application/json") req.Header.Set(versionHeader, r.version) res, err := r.client.Do(req) @@ -188,6 +203,15 @@ func (r *client) clusterURL(version, namespace, name string, extra ...string) st return u.String() } +func (r *client) addAuthorizationHeader(req *http.Request) { + tokenContent := r.tokenMonitor.GetContent() + if tokenContent != nil { + if tokenData, ok := tokenContent.(*authentication.TokenData); ok { + req.Header.Add("Authorization", fmt.Sprintf("%s %s", tokenData.TokenType, tokenData.AccessToken)) + } + } +} + func validateURL(u string) (*url.URL, error) { uri, err := url.ParseRequestURI(u) if err != nil { diff --git a/pkg/authentication/tokendata.go b/pkg/authentication/tokendata.go new file mode 100644 index 00000000..22bd9e9f --- /dev/null +++ b/pkg/authentication/tokendata.go @@ -0,0 +1,73 @@ +// Copyright 2024 Undistro Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authentication + +import ( + "encoding/base64" + "encoding/json" + "errors" + "strings" + "time" + + "github.com/golang-jwt/jwt" +) + +type TokenData struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` +} + +func ParseTokenData(data []byte) (*TokenData, error) { + if data == nil { + return nil, nil + } + var tokenData TokenData + err := json.Unmarshal(data, &tokenData) + if err != nil { + return nil, err + } + return &tokenData, nil +} + +func GetJWTExpiry(tokenData *TokenData) (time.Time, error) { + if tokenData == nil { + return time.Time{}, errors.New("Missing token data") + } + + token := tokenData.AccessToken + parts := strings.Split(token, ".") + if len(parts) != 3 { + return time.Time{}, errors.New("invalid token format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Time{}, err + } + + var claims jwt.MapClaims + err = json.Unmarshal(payload, &claims) + if err != nil { + return time.Time{}, err + } + + exp, ok := claims["exp"].(float64) + if !ok { + return time.Time{}, errors.New("invalid expiry claim") + } + + return time.Unix(int64(exp), 0), nil +} diff --git a/pkg/filemonitor/filemonitor.go b/pkg/filemonitor/filemonitor.go new file mode 100644 index 00000000..91379d13 --- /dev/null +++ b/pkg/filemonitor/filemonitor.go @@ -0,0 +1,293 @@ +// Copyright 2024 Undistro Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filemonitor + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "slices" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/go-logr/logr" + ctrl "sigs.k8s.io/controller-runtime" +) + +type FileContentProcessor func([]byte) (any, error) + +type FileMonitor struct { + mutex sync.RWMutex + fileContent any + filePath string + process FileContentProcessor + log logr.Logger +} + +func NewFileMonitor(filePath string, process FileContentProcessor) *FileMonitor { + log := ctrl.Log.WithValues("service", "MonitorFile", "monitored_file", filePath) + return &FileMonitor{ + filePath: filePath, + process: process, + log: log, + } +} + +func (fm *FileMonitor) logError(err error, msg string) { + fm.log.Error(err, msg) +} + +func (fm *FileMonitor) logInfo(msg string, args ...any) { + if len(args) > 0 { + fm.log.Info(fmt.Sprintf(msg, args...)) + } else { + fm.log.Info(msg) + } +} + +func (fm *FileMonitor) updateContent() error { + content, err := os.ReadFile(fm.filePath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + content = nil + } + + var contentVal any + if fm.process != nil { + contentVal, err = fm.process(content) + if err != nil { + return err + } + } else { + contentVal = content + } + fm.mutex.Lock() + defer fm.mutex.Unlock() + fm.fileContent = contentVal + return nil +} + +func (fm *FileMonitor) GetContent() any { + fm.mutex.RLock() + defer fm.mutex.RUnlock() + return fm.fileContent +} + +func (fm *FileMonitor) MonitorFile(doneCh <-chan struct{}) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + fm.logError(err, "Error creating watcher") + return + } + defer watcher.Close() + errCh := make(chan error) + + go func() { + trackedFiles, err := fm.handleFileWatch(watcher) + if err != nil { + errCh <- err + return + } + + err = fm.updateContent() + if err != nil { + errCh <- err + return + } + + fileWatch := false + for { + if fileWatch { + newTrackedFiles, err := fm.handleFileWatch(watcher) + if err != nil { + fm.logError(err, "Error handling file watch, pausing for a few seconds") + time.Sleep(10 * time.Second) + continue + } + trackedFiles = newTrackedFiles + err = fm.updateContent() + if err != nil { + fm.logError(err, "Error updating content") + } else { + fm.logInfo("File content updated successfully") + } + fileWatch = false + } + select { + case event, ok := <-watcher.Events: + if !ok { + errCh <- errors.New("watcher closed event channel") + return + } + if _, ok := trackedFiles[event.Name]; ok && event.Has(fsnotify.Write|fsnotify.Remove|fsnotify.Create) { + if event.Has(fsnotify.Write) { + fm.logInfo("File %s has been modified", event.Name) + } else if event.Has(fsnotify.Remove) { + fm.logInfo("File %s has been removed", event.Name) + } else if event.Has(fsnotify.Create) { + fm.logInfo("File %s has been created", event.Name) + } + fileWatch = true + } + case err, ok := <-watcher.Errors: + fm.logError(err, "Unexpected watch error") + if !ok { + errCh <- err + return + } + } + } + }() + + fm.logInfo("Starting to monitor %s", fm.filePath) + select { + case <-doneCh: + fm.logInfo("Termination requested, stopping file watch") + case err := <-errCh: + fm.logError(err, "Unexpected error, stopping file watch") + } +} + +func (fm *FileMonitor) handleFileWatch(watcher *fsnotify.Watcher) (map[string]struct{}, error) { + // Starting from the watch file, walk the links and track their locations and directories. + // We watch modifications for all links and the final file, we do not track any intermediate directories + + trackedFiles, trackedDirs, err := walkLinks(filepath.Split(fm.filePath)) + if err != nil { + return nil, err + } + currentTrackedDirs := watcher.WatchList() + slices.Sort(currentTrackedDirs) + if !slices.Equal(currentTrackedDirs, trackedDirs) { + currentTrackedDirsMap := generateMap(currentTrackedDirs) + trackedDirsMap := generateMap(trackedDirs) + unusedTrackedDirs := removeFromMap(currentTrackedDirsMap, trackedDirsMap) + newTrackedDirs := removeFromMap(trackedDirsMap, currentTrackedDirsMap) + for _, newTrackedDir := range newTrackedDirs { + fm.logInfo("Adding watch for %s", newTrackedDir) + err = watcher.Add(newTrackedDir) + if err != nil { + return nil, err + } + } + for _, unusedTrackedDir := range unusedTrackedDirs { + fm.logInfo("Removing watch for %s", unusedTrackedDir) + err = watcher.Remove(unusedTrackedDir) + if err != nil { + return nil, err + } + } + } + return trackedFiles, nil +} + +func walkLinks(prefix, suffix string) (map[string]struct{}, []string, error) { + prefix = filepath.Clean(prefix) + trackedFilesMap := map[string]struct{}{} + trackedDirsMap := map[string]struct{}{} + suffices := split(suffix) + for len(suffices) > 0 { + filePath := filepath.Join(prefix, suffices[0]) + fileInfo, err := os.Lstat(filePath) + if err != nil { + if !os.IsNotExist(err) { + return nil, nil, err + } + // file doesn't exist, we watch and stop at this point + trackedFilesMap[filePath] = struct{}{} + trackedDirsMap[prefix] = struct{}{} + break + } + fileMode := fileInfo.Mode() + switch { + case fileMode&fs.ModeDir != 0: + // Directory discovered, we move onto the next unless it's the final suffix + if len(suffices) == 1 { + trackedFilesMap[filePath] = struct{}{} + trackedDirsMap[prefix] = struct{}{} + suffices = nil + } else { + prefix = filePath + suffices = suffices[1:] + } + case fileMode&fs.ModeSymlink != 0: + // Symbolic link discovered, we watch this and also follow the link with the remaining suffices + if _, ok := trackedFilesMap[filePath]; ok { + // potential cycle discovered, we stop here + // Note this could miss some cases + suffices = nil + break + } + trackedFilesMap[filePath] = struct{}{} + trackedDirsMap[prefix] = struct{}{} + link, err := os.Readlink(filePath) + if err != nil { + return nil, nil, err + } + if filepath.IsAbs(link) { + prefix = "/" + } + suffices = append(split(link), suffices[1:]...) + case fileMode&fs.ModeType == 0: + // File discovered, either this is the target file or part of a previous link is invalid + // Either way, we stop at this point. + fallthrough + default: + // Everything else, we watch this in case it changes + trackedFilesMap[filePath] = struct{}{} + trackedDirsMap[prefix] = struct{}{} + suffices = nil + } + } + trackedDirs := []string{} + for trackedDir := range trackedDirsMap { + trackedDirs = append(trackedDirs, trackedDir) + } + slices.Sort(trackedDirs) + return trackedFilesMap, trackedDirs, nil +} + +func split(file string) []string { + dir, base := filepath.Split(file) + dir = filepath.Clean(dir) + if dir == "." || dir == "/" { + return []string{base} + } else { + return append(split(dir), base) + } +} + +func removeFromMap(sourceMap, removalsMap map[string]struct{}) []string { + result := []string{} + for key := range sourceMap { + if _, ok := removalsMap[key]; !ok { + result = append(result, key) + } + } + return result +} + +func generateMap(sliceValues []string) map[string]struct{} { + result := map[string]struct{}{} + for _, value := range sliceValues { + result[value] = struct{}{} + } + return result +} diff --git a/pkg/filemonitor/filemonitor_linux_test.go b/pkg/filemonitor/filemonitor_linux_test.go new file mode 100644 index 00000000..438eec6d --- /dev/null +++ b/pkg/filemonitor/filemonitor_linux_test.go @@ -0,0 +1,113 @@ +// Copyright 2024 Undistro Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package filemonitor + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestFileMonitorMonitorLink(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "example") + if err != nil { + t.Fatal("Failed to create temp dir:", err) + } + defer os.RemoveAll(tmpDir) + + tmpfileName := filepath.Join(tmpDir, "test") + + // Write initial content + initialContent := "Initial content" + err = os.WriteFile(tmpfileName, []byte(initialContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + + linkFileName := tmpfileName + "_link" + fm := NewFileMonitor(linkFileName, processContent) + + // Start monitoring in a goroutine + done := make(chan struct{}) + defer close(done) + + go fm.MonitorFile(done) + + // Give some time for the initial read + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != nil { + t.Errorf("Expected empty content, got %q", fm.GetContent()) + } + + // create the link + os.Symlink(tmpfileName, linkFileName) + + // Wait for the link to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != initialContent { + t.Errorf("Expected initial content %q, got %q", initialContent, fm.GetContent()) + } + + // Update file content + newContent := "Updated content" + err = os.WriteFile(tmpfileName, []byte(newContent), 0644) + if err != nil { + t.Fatal("Failed to write updated content:", err) + } + + // Wait for the file change to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != newContent { + t.Errorf("Expected updated content %q, got %q", newContent, fm.GetContent()) + } + + // Remove the link + err = os.Remove(linkFileName) + if err != nil { + t.Fatal("Failed to remove link:", err) + } + + // Wait for the file deletion to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != nil { + t.Errorf("Expected empty content, got %q", fm.GetContent()) + } + + // Write new file contents + recreatedFileContent := "Recreated content" + newTmpFileName := tmpfileName + "-2" + err = os.WriteFile(newTmpFileName, []byte(recreatedFileContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + + // create the new link + os.Symlink(newTmpFileName, linkFileName) + + // Wait for the link to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != recreatedFileContent { + t.Errorf("Expected recreated content %q, got %q", recreatedFileContent, fm.GetContent()) + } +} diff --git a/pkg/filemonitor/filemonitor_test.go b/pkg/filemonitor/filemonitor_test.go new file mode 100644 index 00000000..abe62a61 --- /dev/null +++ b/pkg/filemonitor/filemonitor_test.go @@ -0,0 +1,226 @@ +// Copyright 2024 Undistro Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filemonitor + +import ( + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func processContent(content []byte) (any, error) { + if content == nil { + return nil, nil + } else { + return string(content), nil + } +} + +func TestFileMonitorUpdateContent(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "example") + if err != nil { + t.Fatal("Failed to create temp dir:", err) + } + defer os.RemoveAll(tmpDir) + + tmpfileName := filepath.Join(tmpDir, "test") + + // Write initial content + initialContent := "Hello, World!" + err = os.WriteFile(tmpfileName, []byte(initialContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + + fm := NewFileMonitor(tmpfileName, processContent) + err = fm.updateContent() + if err != nil { + t.Fatal("Failed to update content:", err) + } + + if fm.GetContent() != initialContent { + t.Errorf("Expected content %q, got %q", initialContent, fm.GetContent()) + } + + // Update file content + newContent := "Updated content" + err = os.WriteFile(tmpfileName, []byte(newContent), 0644) + if err != nil { + t.Fatal("Failed to write updated content:", err) + } + + err = fm.updateContent() + if err != nil { + t.Fatal("Failed to update content:", err) + } + + if fm.GetContent() != newContent { + t.Errorf("Expected updated content %q, got %q", newContent, fm.GetContent()) + } +} + +func TestFileMonitorMonitorFile(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "example") + if err != nil { + t.Fatal("Failed to create temp dir:", err) + } + defer os.RemoveAll(tmpDir) + + tmpfileName := filepath.Join(tmpDir, "test") + + // Write initial content + initialContent := "Initial content" + err = os.WriteFile(tmpfileName, []byte(initialContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + + fm := NewFileMonitor(tmpfileName, processContent) + + // Start monitoring in a goroutine + done := make(chan struct{}) + defer close(done) + + go fm.MonitorFile(done) + + // Give some time for the initial read + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != initialContent { + t.Errorf("Expected initial content %q, got %q", initialContent, fm.GetContent()) + } + + // Update file content + newContent := "Updated content" + err = os.WriteFile(tmpfileName, []byte(newContent), 0644) + if err != nil { + t.Fatal("Failed to write updated content:", err) + } + + // Wait for the file change to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != newContent { + t.Errorf("Expected updated content %q, got %q", newContent, fm.GetContent()) + } + + // Remove the file + err = os.Remove(tmpfileName) + if err != nil { + t.Fatal("Failed to remove file:", err) + } + + // Wait for the file deletion to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != nil { + t.Errorf("Expected empty content, got %q", fm.GetContent()) + } + + // Write recreated file contents + recreatedFileContent := "Recreated content" + err = os.WriteFile(tmpfileName, []byte(recreatedFileContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + + // Wait for the file change to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != recreatedFileContent { + t.Errorf("Expected recreated content %q, got %q", recreatedFileContent, fm.GetContent()) + } + + // Test moved file + movedFileName := filepath.Join(tmpDir, "moved") + + // Write moved file contents + movedFileContent := "Moved content" + err = os.WriteFile(movedFileName, []byte(movedFileContent), 0644) + if err != nil { + t.Fatal("Failed to write to temp file:", err) + } + err = os.Rename(movedFileName, tmpfileName) + if err != nil { + t.Fatal("Failed to rename temp file:", err) + } + + // Wait for the file change to be detected + time.Sleep(100 * time.Millisecond) + + if fm.GetContent() != movedFileContent { + t.Errorf("Expected moved content %q, got %q", movedFileContent, fm.GetContent()) + } +} + +func TestFileMonitorConcurrentAccess(t *testing.T) { + testFile, err := os.CreateTemp("", "testfile") + if err != nil { + t.Errorf("Error creating temp file: %v", err) + } + defer testFile.Close() + testpath := testFile.Name() + + fm := NewFileMonitor(filepath.Join(testpath), processContent) + fm.fileContent = "Test content" + + // Simulate concurrent reads + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + content := fm.GetContent() + if content != "Test content" { + t.Errorf("Expected content %q, got %q", "Test content", content) + } + }() + } + wg.Wait() + + // Simulate concurrent writes + os.WriteFile(testpath, []byte("This is a test"), 0777) + defer os.Remove(testpath) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := fm.updateContent() + if err != nil { + t.Errorf("Error updating content: %v", err) + } + }() + } + + wg.Wait() +} + +func TestFileMonitorInvalidFile(t *testing.T) { + // Create a directory instead of a file + tmpdir, err := os.MkdirTemp("", "example") + if err != nil { + t.Fatal("Failed to create temp directory:", err) + } + defer os.RemoveAll(tmpdir) + + fm := NewFileMonitor(tmpdir, processContent) + err = fm.updateContent() + if err == nil { + t.Error("Expected error for directory, got nil") + } +}