Skip to content

Commit

Permalink
Merge pull request cs3org#4454 from butonic/skip-unnecessary-received…
Browse files Browse the repository at this point in the history
…-share-retrieval

Skip unnecessary share retrieval
  • Loading branch information
dragonchaser authored Jan 18, 2024
2 parents bde86a3 + 5a60236 commit aa54276
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 150 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Bugfix: Skip unnecessary share retrieval

https://github.com/cs3org/reva/pull/4454
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/cs3org/reva/v2/internal/http/services/owncloud/ocs/response"
"github.com/cs3org/reva/v2/pkg/appctx"
"github.com/cs3org/reva/v2/pkg/conversions"
"github.com/cs3org/reva/v2/pkg/errtypes"
"github.com/cs3org/reva/v2/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/pkg/errors"
Expand Down Expand Up @@ -63,62 +64,110 @@ func (h *Handler) AcceptReceivedShare(w http.ResponseWriter, r *http.Request) {
return
}

rs, ocsResponse := getReceivedShareFromID(ctx, client, shareID)
receivedShare, ocsResponse := getReceivedShareFromID(ctx, client, shareID)
if ocsResponse != nil {
response.WriteOCSResponse(w, r, *ocsResponse, nil)
return
}

sharedResource, ocsResponse := getSharedResource(ctx, client, rs.Share.Share.ResourceId)
sharedResource, ocsResponse := getSharedResource(ctx, client, receivedShare.Share.ResourceId)
if ocsResponse != nil {
response.WriteOCSResponse(w, r, *ocsResponse, nil)
return
}

lrs, ocsResponse := getSharesList(ctx, client)
if ocsResponse != nil {
response.WriteOCSResponse(w, r, *ocsResponse, nil)
mount, unmountedShares, err := GetMountpointAndUnmountedShares(ctx, client, sharedResource.Info)
if err != nil {
response.WriteOCSError(w, r, response.MetaServerError.StatusCode, "could not determine mountpoint", err)
return
}

// first update the requested share
receivedShare.State = collaboration.ShareState_SHARE_STATE_ACCEPTED
// we need to add a path to the share
receivedShare.MountPoint = &provider.Reference{
Path: mount,
}

updateMask := &fieldmaskpb.FieldMask{Paths: []string{"state", "mount_point"}}
data, meta, err := h.updateReceivedShare(r.Context(), receivedShare, updateMask)
if err != nil {
// we log an error for affected shares, for the actual share we return an error
response.WriteOCSData(w, r, meta, data, err)
return
}
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})

// then update other unmounted shares to the same resource
for _, rs := range unmountedShares {
if rs.GetShare().GetId().GetOpaqueId() == shareID {
// we already updated this one
continue
}

rs.State = collaboration.ShareState_SHARE_STATE_ACCEPTED
// set the same mountpoint as for the requested received share
rs.MountPoint = &provider.Reference{
Path: mount,
}

_, _, err := h.updateReceivedShare(r.Context(), rs, updateMask)
if err != nil {
// we log an error for affected shares, the actual share was successful
appctx.GetLogger(ctx).Error().Err(err).Str("received_share", shareID).Str("affected_share", rs.GetShare().GetId().GetOpaqueId()).Msg("could not update affected received share")
}
}
}

// GetMountpointAndUnmountedShares returns a new or existing mountpoint for the given info and produces a list of unmounted received shares for the same resource
func GetMountpointAndUnmountedShares(ctx context.Context, gwc gateway.GatewayAPIClient, info *provider.ResourceInfo) (string, []*collaboration.ReceivedShare, error) {
unmountedShares := []*collaboration.ReceivedShare{}
receivedShares, err := listReceivedShares(ctx, gwc)
if err != nil {
return "", unmountedShares, err
}

// we need to sort the received shares by mount point in order to make things easier to evaluate.
base := path.Base(sharedResource.GetInfo().GetPath())
mount := base
var mountedShares []*collaboration.ReceivedShare
sharesToAccept := map[string]bool{shareID: true}
for _, s := range lrs.Shares {
if utils.ResourceIDEqual(s.Share.ResourceId, rs.Share.Share.GetResourceId()) {
mount := filepath.Clean(info.Name)
existingMountpoint := ""
mountedShares := make([]*collaboration.ReceivedShare, 0, len(receivedShares))
for _, s := range receivedShares {
if utils.ResourceIDEqual(s.Share.ResourceId, info.GetId()) {
if s.State == collaboration.ShareState_SHARE_STATE_ACCEPTED {
mount = s.MountPoint.Path
// a share to the resource already exists and is mounted, remember the mount point
_, err := utils.GetResourceByID(ctx, s.Share.ResourceId, gwc)
if err == nil {
existingMountpoint = s.MountPoint.Path
}
} else {
sharesToAccept[s.Share.Id.OpaqueId] = true
}
} else {
if s.State == collaboration.ShareState_SHARE_STATE_ACCEPTED {
s.Hidden = h.getReceivedShareHideFlagFromShareID(r.Context(), shareID)
mountedShares = append(mountedShares, s)
// a share to the resource already exists but is not mounted, collect the unmounted share
unmountedShares = append(unmountedShares, s)
}
}

if s.State == collaboration.ShareState_SHARE_STATE_ACCEPTED {
mountedShares = append(mountedShares, s)
}
}

compareMountPoint := func(i, j int) bool {
sort.Slice(mountedShares, func(i, j int) bool {
return mountedShares[i].MountPoint.Path > mountedShares[j].MountPoint.Path
})

if existingMountpoint != "" {
// we want to reuse the same mountpoint for all unmounted shares to the same resource
return existingMountpoint, unmountedShares, nil
}
sort.Slice(mountedShares, compareMountPoint)

// now we have a list of shares, we want to iterate over all of them and check for name collisions
// we have a list of shares, we want to iterate over all of them and check for name collisions
for i, ms := range mountedShares {
if ms.MountPoint.Path == mount {
// does the shared resource still exist?
res, err := client.Stat(ctx, &provider.StatRequest{
Ref: &provider.Reference{
ResourceId: ms.Share.ResourceId,
},
})
if err == nil && res.Status.Code == rpc.Code_CODE_OK {
_, err := utils.GetResourceByID(ctx, ms.Share.ResourceId, gwc)
if err == nil {
// The mount point really already exists, we need to insert a number into the filename
ext := filepath.Ext(base)
name := strings.TrimSuffix(base, ext)
ext := filepath.Ext(mount)
name := strings.TrimSuffix(mount, ext)
// be smart about .tar.(gz|bz) files
if strings.HasSuffix(name, ".tar") {
name = strings.TrimSuffix(name, ".tar")
Expand All @@ -130,26 +179,7 @@ func (h *Handler) AcceptReceivedShare(w http.ResponseWriter, r *http.Request) {
// TODO we could delete shares here if the stat returns code NOT FOUND ... but listening for file deletes would be better
}
}
// we need to add a path to the share
receivedShare := &collaboration.ReceivedShare{
Share: &collaboration.Share{
Id: &collaboration.ShareId{OpaqueId: shareID},
},
State: collaboration.ShareState_SHARE_STATE_ACCEPTED,
Hidden: h.getReceivedShareHideFlagFromShareID(r.Context(), shareID),
MountPoint: &provider.Reference{
Path: mount,
},
}
updateMask := &fieldmaskpb.FieldMask{Paths: []string{"state", "hidden", "mount_point"}}

for id := range sharesToAccept {
data := h.updateReceivedShare(w, r, receivedShare, updateMask)
// only render the data for the changed share
if id == shareID && data != nil {
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
}
}
return mount, unmountedShares, nil
}

// RejectReceivedShare handles DELETE Requests on /apps/files_sharing/api/v1/shares/{shareid}
Expand All @@ -166,15 +196,15 @@ func (h *Handler) RejectReceivedShare(w http.ResponseWriter, r *http.Request) {
Share: &collaboration.Share{
Id: &collaboration.ShareId{OpaqueId: shareID},
},
State: collaboration.ShareState_SHARE_STATE_REJECTED,
Hidden: h.getReceivedShareHideFlagFromShareID(r.Context(), shareID),
State: collaboration.ShareState_SHARE_STATE_REJECTED,
}
updateMask := &fieldmaskpb.FieldMask{Paths: []string{"state", "hidden"}}
updateMask := &fieldmaskpb.FieldMask{Paths: []string{"state"}}

data := h.updateReceivedShare(w, r, receivedShare, updateMask)
if data != nil {
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
data, meta, err := h.updateReceivedShare(r.Context(), receivedShare, updateMask)
if err != nil {
response.WriteOCSData(w, r, meta, nil, err)
}
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
}

func (h *Handler) UpdateReceivedShare(w http.ResponseWriter, r *http.Request) {
Expand All @@ -199,18 +229,17 @@ func (h *Handler) UpdateReceivedShare(w http.ResponseWriter, r *http.Request) {

rs, _ := getReceivedShareFromID(r.Context(), client, shareID)
if rs != nil && rs.Share != nil {
receivedShare.State = rs.Share.State
receivedShare.State = rs.State
}

data := h.updateReceivedShare(w, r, receivedShare, updateMask)
if data != nil {
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
data, meta, err := h.updateReceivedShare(r.Context(), receivedShare, updateMask)
if err != nil {
response.WriteOCSData(w, r, meta, nil, err)
}
// TODO: do we need error handling here?
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
}

func (h *Handler) updateReceivedShare(w http.ResponseWriter, r *http.Request, receivedShare *collaboration.ReceivedShare, fieldMask *fieldmaskpb.FieldMask) *conversions.ShareData {
ctx := r.Context()
func (h *Handler) updateReceivedShare(ctx context.Context, receivedShare *collaboration.ReceivedShare, fieldMask *fieldmaskpb.FieldMask) (*conversions.ShareData, response.Meta, error) {
logger := appctx.GetLogger(ctx)

updateShareRequest := &collaboration.UpdateReceivedShareRequest{
Expand All @@ -220,51 +249,43 @@ func (h *Handler) updateReceivedShare(w http.ResponseWriter, r *http.Request, re

client, err := h.getClient()
if err != nil {
response.WriteOCSError(w, r, response.MetaServerError.StatusCode, "error getting grpc gateway client", err)
return nil
return nil, response.MetaServerError, errors.Wrap(err, "error getting grpc gateway client")
}

shareRes, err := client.UpdateReceivedShare(ctx, updateShareRequest)
if err != nil {
response.WriteOCSError(w, r, response.MetaServerError.StatusCode, "grpc update received share request failed", err)
return nil
return nil, response.MetaServerError, errors.Wrap(err, "grpc update received share request failed")
}

if shareRes.Status.Code != rpc.Code_CODE_OK {
if shareRes.Status.Code == rpc.Code_CODE_NOT_FOUND {
response.WriteOCSError(w, r, response.MetaNotFound.StatusCode, "not found", nil)
return nil
return nil, response.MetaNotFound, errors.New(shareRes.Status.Message)
}
response.WriteOCSError(w, r, response.MetaServerError.StatusCode, "grpc update received share request failed", errors.Errorf("code: %d, message: %s", shareRes.Status.Code, shareRes.Status.Message))
return nil
return nil, response.MetaServerError, errors.Errorf("grpc update received share request failed: code: %d, message: %s", shareRes.Status.Code, shareRes.Status.Message)
}

rs := shareRes.GetShare()

info, status, err := h.getResourceInfoByID(ctx, client, rs.Share.ResourceId)
if err != nil || status.Code != rpc.Code_CODE_OK {
h.logProblems(logger, status, err, "could not stat, skipping")
response.WriteOCSError(w, r, response.MetaServerError.StatusCode, "grpc get resource info failed", errors.Errorf("code: %d, message: %s", status.Code, status.Message))
return nil
return nil, response.MetaServerError, errors.Errorf("grpc get resource info failed: code: %d, message: %s", status.Code, status.Message)
}

data, err := conversions.CS3Share2ShareData(r.Context(), rs.Share)
if err != nil {
logger.Debug().Interface("share", rs.Share).Interface("shareData", data).Err(err).Msg("could not CS3Share2ShareData, skipping")
}
data := conversions.CS3Share2ShareData(ctx, rs.Share)

data.State = mapState(rs.GetState())
data.Hidden = rs.GetHidden()

h.addFileInfo(ctx, data, info)
h.mapUserIds(r.Context(), client, data)
h.mapUserIds(ctx, client, data)

if data.State == ocsStateAccepted {
// Needed because received shares can be jailed in a folder in the users home
data.Path = path.Join(h.sharePrefix, path.Base(info.Path))
}

return data
return data, response.MetaOK, nil
}

func (h *Handler) updateReceivedFederatedShare(w http.ResponseWriter, r *http.Request, shareID string, rejectShare bool) {
Expand Down Expand Up @@ -337,21 +358,8 @@ func (h *Handler) updateReceivedFederatedShare(w http.ResponseWriter, r *http.Re
response.WriteOCSSuccess(w, r, []*conversions.ShareData{data})
}

// getReceivedShareHideFlagFromShareId returns the hide flag of a received share based on its ID.
func (h *Handler) getReceivedShareHideFlagFromShareID(ctx context.Context, shareID string) bool {
client, err := h.getClient()
if err != nil {
return false
}
rs, _ := getReceivedShareFromID(ctx, client, shareID)
if rs != nil {
return rs.GetShare().GetHidden()
}
return false
}

// getReceivedShareFromID uses a client to the gateway to fetch a share based on its ID.
func getReceivedShareFromID(ctx context.Context, client gateway.GatewayAPIClient, shareID string) (*collaboration.GetReceivedShareResponse, *response.Response) {
func getReceivedShareFromID(ctx context.Context, client gateway.GatewayAPIClient, shareID string) (*collaboration.ReceivedShare, *response.Response) {
s, err := client.GetReceivedShare(ctx, &collaboration.GetReceivedShareRequest{
Ref: &collaboration.ShareReference{
Spec: &collaboration.ShareReference_Id{
Expand All @@ -376,7 +384,7 @@ func getReceivedShareFromID(ctx context.Context, client gateway.GatewayAPIClient
return nil, arbitraryOcsResponse(response.MetaBadRequest.StatusCode, e.Error())
}

return s, nil
return s.Share, nil
}

// getSharedResource attempts to get a shared resource from the storage from the resource reference.
Expand All @@ -403,23 +411,17 @@ func getSharedResource(ctx context.Context, client gateway.GatewayAPIClient, res
return res, nil
}

// getSharedResource gets the list of all shares for the current user.
func getSharesList(ctx context.Context, client gateway.GatewayAPIClient) (*collaboration.ListReceivedSharesResponse, *response.Response) {
shares, err := client.ListReceivedShares(ctx, &collaboration.ListReceivedSharesRequest{})
// listReceivedShares list all received shares for the current user.
func listReceivedShares(ctx context.Context, client gateway.GatewayAPIClient) ([]*collaboration.ReceivedShare, error) {
res, err := client.ListReceivedShares(ctx, &collaboration.ListReceivedSharesRequest{})
if err != nil {
e := errors.Wrap(err, "error getting shares list")
return nil, arbitraryOcsResponse(response.MetaNotFound.StatusCode, e.Error())
return nil, errtypes.InternalError("grpc list received shares request failed")
}

if shares.Status.Code != rpc.Code_CODE_OK {
if shares.Status.Code == rpc.Code_CODE_NOT_FOUND {
e := fmt.Errorf("not found")
return nil, arbitraryOcsResponse(response.MetaNotFound.StatusCode, e.Error())
}
e := fmt.Errorf(shares.GetStatus().GetMessage())
return nil, arbitraryOcsResponse(response.MetaServerError.StatusCode, e.Error())
if err := errtypes.NewErrtypeFromStatus(res.Status); err != nil {
return nil, err
}
return shares, nil
return res.Shares, nil
}

// arbitraryOcsResponse abstracts the boilerplate that is creating a response.Response struct.
Expand Down
Loading

0 comments on commit aa54276

Please sign in to comment.