From 3e87e0c4559736f9476eba943bac8d67cde91aad Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:57:23 +0100 Subject: [PATCH] feat: use one transaction for `/admin/recovery/code` (#4225) --- selfservice/strategy/code/strategy.go | 1 + .../strategy/code/strategy_recovery_admin.go | 31 ++++++++++++------- selfservice/strategy/link/strategy.go | 1 + .../strategy/link/strategy_recovery.go | 15 +++++---- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index aaac228508fd..85070509402c 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -68,6 +68,7 @@ type ( x.WriterProvider x.LoggingProvider x.TracingProvider + x.TransactionPersistenceProvider config.Provider diff --git a/selfservice/strategy/code/strategy_recovery_admin.go b/selfservice/strategy/code/strategy_recovery_admin.go index 63aa36a90edd..d1626f8a3987 100644 --- a/selfservice/strategy/code/strategy_recovery_admin.go +++ b/selfservice/strategy/code/strategy_recovery_admin.go @@ -4,10 +4,13 @@ package code import ( + "context" "net/http" "net/url" "time" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/julienschmidt/httprouter" "github.com/pkg/errors" @@ -184,16 +187,12 @@ func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http. })). WithMetaLabel(text.NewInfoNodeLabelRecoveryCode()), ) + rawCode := GenerateCode() recoveryFlow.UI.Nodes. Append(node.NewInputField("method", s.RecoveryStrategyID(), node.CodeGroup, node.InputAttributeTypeSubmit). WithMetaLabel(text.NewInfoNodeLabelContinue())) - if err := s.deps.RecoveryFlowPersister().CreateRecoveryFlow(ctx, recoveryFlow); err != nil { - s.deps.Writer().WriteError(w, r, err) - return - } - id, err := s.deps.IdentityPool().GetIdentity(ctx, p.IdentityID, identity.ExpandDefault) if notFoundErr := sqlcon.ErrNoRows; errors.As(err, ¬FoundErr) { s.deps.Writer().WriteError(w, r, notFoundErr.WithReasonf("could not find identity")) @@ -203,14 +202,22 @@ func (s *Strategy) createRecoveryCodeForIdentity(w http.ResponseWriter, r *http. return } - rawCode := GenerateCode() + if err := s.deps.TransactionalPersisterProvider().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if err := s.deps.RecoveryFlowPersister().CreateRecoveryFlow(ctx, recoveryFlow); err != nil { + return err + } + + if _, err := s.deps.RecoveryCodePersister().CreateRecoveryCode(ctx, &CreateRecoveryCodeParams{ + RawCode: rawCode, + CodeType: RecoveryCodeTypeAdmin, + ExpiresIn: expiresIn, + FlowID: recoveryFlow.ID, + IdentityID: id.ID, + }); err != nil { + return err + } - if _, err := s.deps.RecoveryCodePersister().CreateRecoveryCode(ctx, &CreateRecoveryCodeParams{ - RawCode: rawCode, - CodeType: RecoveryCodeTypeAdmin, - ExpiresIn: expiresIn, - FlowID: recoveryFlow.ID, - IdentityID: id.ID, + return nil }); err != nil { s.deps.Writer().WriteError(w, r, err) return diff --git a/selfservice/strategy/link/strategy.go b/selfservice/strategy/link/strategy.go index cdf8356cc4b3..5cb78378118b 100644 --- a/selfservice/strategy/link/strategy.go +++ b/selfservice/strategy/link/strategy.go @@ -43,6 +43,7 @@ type ( x.WriterProvider x.LoggingProvider x.TracingProvider + x.TransactionPersistenceProvider config.Provider diff --git a/selfservice/strategy/link/strategy_recovery.go b/selfservice/strategy/link/strategy_recovery.go index 0ad04d244817..6c92082e47ef 100644 --- a/selfservice/strategy/link/strategy_recovery.go +++ b/selfservice/strategy/link/strategy_recovery.go @@ -10,6 +10,8 @@ import ( "net/url" "time" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/julienschmidt/httprouter" "github.com/pkg/errors" @@ -171,11 +173,6 @@ func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http. return } - if err := s.d.RecoveryFlowPersister().CreateRecoveryFlow(r.Context(), req); err != nil { - s.d.Writer().WriteError(w, r, err) - return - } - id, err := s.d.IdentityPool().GetIdentity(r.Context(), p.IdentityID, identity.ExpandDefault) if errors.Is(err, sqlcon.ErrNoRows) { s.d.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("The requested identity id does not exist.").WithWrap(err))) @@ -186,7 +183,13 @@ func (s *Strategy) createRecoveryLinkForIdentity(w http.ResponseWriter, r *http. } token := NewAdminRecoveryToken(id.ID, req.ID, expiresIn) - if err := s.d.RecoveryTokenPersister().CreateRecoveryToken(r.Context(), token); err != nil { + if err := s.d.TransactionalPersisterProvider().Transaction(r.Context(), func(ctx context.Context, c *pop.Connection) error { + if err := s.d.RecoveryFlowPersister().CreateRecoveryFlow(ctx, req); err != nil { + return err + } + + return s.d.RecoveryTokenPersister().CreateRecoveryToken(ctx, token) + }); err != nil { s.d.Writer().WriteError(w, r, err) return }