Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup multipart download and allow for concurrent downloads #705

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 75 additions & 41 deletions Sources/Soto/Extensions/S3/S3+multipart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import Atomics
import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOPosix
import SotoCore
Expand Down Expand Up @@ -50,13 +51,15 @@ extension S3 {
/// - parameters:
/// - input: The GetObjectRequest shape that contains the details of the object request.
/// - partSize: Size of each part to be downloaded
/// - concurrentDownloads: How many downloads can you have running at one time
/// - outputStream: Function to be called for each downloaded part. Called with data block and file size
/// - returns: The complete file size once the multipart download has finished.
public func multipartDownload(
_ input: GetObjectRequest,
partSize: Int = 5 * 1024 * 1024,
concurrentDownloads: Int = 4,
logger: Logger = AWSClient.loggingDisabled,
outputStream: @escaping (ByteBuffer, Int64) async throws -> Void
outputStream: @escaping @Sendable (ByteBuffer, Int64) async throws -> Void
) async throws -> Int64 {
// get object size before downloading
let headRequest = S3.HeadObjectRequest(
Expand All @@ -77,48 +80,76 @@ extension S3 {
throw S3ErrorType.multipart.downloadEmpty(message: "Content length is unexpectedly zero")
}

// download part task
func downloadPartTask(offset: Int64, partSize: Int64) -> Task<GetObjectOutput, Swift.Error> {
let range = "bytes=\(offset)-\(offset + Int64(partSize - 1))"
let getRequest = S3.GetObjectRequest(
bucket: input.bucket,
key: input.key,
range: range,
sseCustomerAlgorithm: input.sseCustomerAlgorithm,
sseCustomerKey: input.sseCustomerKey,
sseCustomerKeyMD5: input.sseCustomerKeyMD5,
versionId: input.versionId
)
return Task {
try await getObject(getRequest, logger: logger)
}
}

// save part task
func savePart(downloadedPart: GetObjectOutput) async throws {
try await outputStream(downloadedPart.body.collect(upTo: .max), contentLength)
}

let partSize: Int64 = numericCast(partSize)
var offset = min(partSize, contentLength)
var downloadedPartTask = downloadPartTask(offset: 0, partSize: offset)
while offset < contentLength {
// wait for previous download
let downloadedPart = try await downloadedPartTask.value
try await withThrowingTaskGroup(of: (Int, ByteBuffer).self) { group in
/// Structure used to store downloaded buffers and then save them as and when
/// needed
struct DownloadedBuffers {
let outputStream: @Sendable (ByteBuffer) async throws -> Void
var buffers: [ByteBuffer?]
var bufferSavedIndex: Int

init(numberOfBuffers: Int, outputStream: @escaping @Sendable (ByteBuffer) async throws -> Void) {
self.outputStream = outputStream
self.buffers = Array(repeating: nil, count: numberOfBuffers)
self.bufferSavedIndex = 0
}

// start next download
let downloadPartSize = min(partSize, contentLength - offset)
downloadedPartTask = downloadPartTask(offset: offset, partSize: downloadPartSize)
offset += downloadPartSize
mutating func saveBuffer(index: Int, buffer: ByteBuffer) async throws {
assert(index >= 0 && index < self.buffers.count)
self.buffers[index] = buffer
while self.bufferSavedIndex < self.buffers.count, let bufferToSave = self.buffers[bufferSavedIndex] {
self.buffers[self.bufferSavedIndex] = nil
self.bufferSavedIndex += 1
try await self.outputStream(bufferToSave)
}
}
}
let partSize64: Int64 = numericCast(partSize)
var count = 0
var offset: Int64 = 0
let numberOfParts: Int = numericCast((contentLength - 1) / partSize64) + 1
var downloadBuffers = DownloadedBuffers(numberOfBuffers: numberOfParts) { buffer in
try await outputStream(buffer, contentLength)
}
// while we still have parts to download
while count < numberOfParts {
if count > concurrentDownloads {
// if count is greater than concurrentDownloads then start waiting for
// parts that have downloaded to save them
if let (index, buffer) = try await group.next() {
// save the buffer
try await downloadBuffers.saveBuffer(index: index, buffer: buffer)
}
}
let index = count
let currentPartSize = min(partSize64, contentLength - offset)
let currentOffset = offset
// add task downloading from S3
group.addTask {
let range = "bytes=\(currentOffset)-\(currentOffset + currentPartSize - 1)"
let getRequest = S3.GetObjectRequest(
bucket: input.bucket,
key: input.key,
range: range,
sseCustomerAlgorithm: input.sseCustomerAlgorithm,
sseCustomerKey: input.sseCustomerKey,
sseCustomerKeyMD5: input.sseCustomerKeyMD5,
versionId: input.versionId
)
let getObjectOutput = try await getObject(getRequest, logger: logger)
let buffer = try await getObjectOutput.body.collect(upTo: partSize)
return (index, buffer)
}
offset += partSize64
count += 1
}

// save part
try await savePart(downloadedPart: downloadedPart)
// save the remaining parts
for try await(index, buffer) in group {
// save the buffer
try await downloadBuffers.saveBuffer(index: index, buffer: buffer)
}
}
// wait for last download
let downloadedPart = try await downloadedPartTask.value
// and save part
try await savePart(downloadedPart: downloadedPart)

return contentLength
}

Expand All @@ -128,6 +159,7 @@ extension S3 {
/// - input: The GetObjectRequest shape that contains the details of the object request.
/// - partSize: Size of each part to be downloaded
/// - filename: Filename to save download to
/// - concurrentDownloads: How many downloads can you have running at one time
/// - threadPool: Thread pool used to save file
/// - logger: logger
/// - progress: Callback that returns the progress of the download. It is called after each part is downloaded with a value
Expand All @@ -137,9 +169,10 @@ extension S3 {
_ input: GetObjectRequest,
partSize: Int = 5 * 1024 * 1024,
filename: String,
concurrentDownloads: Int = 4,
threadPool: NIOThreadPool = .singleton,
logger: Logger = AWSClient.loggingDisabled,
progress: @escaping (Double) async throws -> Void = { _ in }
progress: @escaping @Sendable (Double) async throws -> Void = { _ in }
) async throws -> Int64 {
let eventLoop = self.client.eventLoopGroup.any()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
Expand All @@ -151,6 +184,7 @@ extension S3 {
downloaded = try await self.multipartDownload(
input,
partSize: partSize,
concurrentDownloads: concurrentDownloads,
logger: logger
) { byteBuffer, fileSize in
let bufferSize = byteBuffer.readableBytes
Expand Down
Loading