diff --git a/Sources/Soto/Extensions/S3/S3+multipart.swift b/Sources/Soto/Extensions/S3/S3+multipart.swift index 33be3dff27..dff65c958d 100644 --- a/Sources/Soto/Extensions/S3/S3+multipart.swift +++ b/Sources/Soto/Extensions/S3/S3+multipart.swift @@ -14,6 +14,7 @@ import Atomics import Logging +import NIOConcurrencyHelpers import NIOCore import NIOPosix import SotoCore @@ -50,13 +51,15 @@ extension S3 { /// - parameters: /// - input: The GetObjectRequest shape that contains the details of the object request. /// - partSize: Size of each part to be downloaded + /// - concurrentDownloads: How many downloads can you have running at one time /// - outputStream: Function to be called for each downloaded part. Called with data block and file size /// - returns: The complete file size once the multipart download has finished. public func multipartDownload( _ input: GetObjectRequest, partSize: Int = 5 * 1024 * 1024, + concurrentDownloads: Int = 4, logger: Logger = AWSClient.loggingDisabled, - outputStream: @escaping (ByteBuffer, Int64) async throws -> Void + outputStream: @escaping @Sendable (ByteBuffer, Int64) async throws -> Void ) async throws -> Int64 { // get object size before downloading let headRequest = S3.HeadObjectRequest( @@ -77,48 +80,76 @@ extension S3 { throw S3ErrorType.multipart.downloadEmpty(message: "Content length is unexpectedly zero") } - // download part task - func downloadPartTask(offset: Int64, partSize: Int64) -> Task { - let range = "bytes=\(offset)-\(offset + Int64(partSize - 1))" - let getRequest = S3.GetObjectRequest( - bucket: input.bucket, - key: input.key, - range: range, - sseCustomerAlgorithm: input.sseCustomerAlgorithm, - sseCustomerKey: input.sseCustomerKey, - sseCustomerKeyMD5: input.sseCustomerKeyMD5, - versionId: input.versionId - ) - return Task { - try await getObject(getRequest, logger: logger) - } - } - - // save part task - func savePart(downloadedPart: GetObjectOutput) async throws { - try await outputStream(downloadedPart.body.collect(upTo: .max), contentLength) - } - - let partSize: Int64 = numericCast(partSize) - var offset = min(partSize, contentLength) - var downloadedPartTask = downloadPartTask(offset: 0, partSize: offset) - while offset < contentLength { - // wait for previous download - let downloadedPart = try await downloadedPartTask.value + try await withThrowingTaskGroup(of: (Int, ByteBuffer).self) { group in + /// Structure used to store downloaded buffers and then save them as and when + /// needed + struct DownloadedBuffers { + let outputStream: @Sendable (ByteBuffer) async throws -> Void + var buffers: [ByteBuffer?] + var bufferSavedIndex: Int + + init(numberOfBuffers: Int, outputStream: @escaping @Sendable (ByteBuffer) async throws -> Void) { + self.outputStream = outputStream + self.buffers = Array(repeating: nil, count: numberOfBuffers) + self.bufferSavedIndex = 0 + } - // start next download - let downloadPartSize = min(partSize, contentLength - offset) - downloadedPartTask = downloadPartTask(offset: offset, partSize: downloadPartSize) - offset += downloadPartSize + mutating func saveBuffer(index: Int, buffer: ByteBuffer) async throws { + assert(index >= 0 && index < self.buffers.count) + self.buffers[index] = buffer + while self.bufferSavedIndex < self.buffers.count, let bufferToSave = self.buffers[bufferSavedIndex] { + self.buffers[self.bufferSavedIndex] = nil + self.bufferSavedIndex += 1 + try await self.outputStream(bufferToSave) + } + } + } + let partSize64: Int64 = numericCast(partSize) + var count = 0 + var offset: Int64 = 0 + let numberOfParts: Int = numericCast((contentLength - 1) / partSize64) + 1 + var downloadBuffers = DownloadedBuffers(numberOfBuffers: numberOfParts) { buffer in + try await outputStream(buffer, contentLength) + } + // while we still have parts to download + while count < numberOfParts { + if count > concurrentDownloads { + // if count is greater than concurrentDownloads then start waiting for + // parts that have downloaded to save them + if let (index, buffer) = try await group.next() { + // save the buffer + try await downloadBuffers.saveBuffer(index: index, buffer: buffer) + } + } + let index = count + let currentPartSize = min(partSize64, contentLength - offset) + let currentOffset = offset + // add task downloading from S3 + group.addTask { + let range = "bytes=\(currentOffset)-\(currentOffset + currentPartSize - 1)" + let getRequest = S3.GetObjectRequest( + bucket: input.bucket, + key: input.key, + range: range, + sseCustomerAlgorithm: input.sseCustomerAlgorithm, + sseCustomerKey: input.sseCustomerKey, + sseCustomerKeyMD5: input.sseCustomerKeyMD5, + versionId: input.versionId + ) + let getObjectOutput = try await getObject(getRequest, logger: logger) + let buffer = try await getObjectOutput.body.collect(upTo: partSize) + return (index, buffer) + } + offset += partSize64 + count += 1 + } - // save part - try await savePart(downloadedPart: downloadedPart) + // save the remaining parts + for try await(index, buffer) in group { + // save the buffer + try await downloadBuffers.saveBuffer(index: index, buffer: buffer) + } } - // wait for last download - let downloadedPart = try await downloadedPartTask.value - // and save part - try await savePart(downloadedPart: downloadedPart) - return contentLength } @@ -128,6 +159,7 @@ extension S3 { /// - input: The GetObjectRequest shape that contains the details of the object request. /// - partSize: Size of each part to be downloaded /// - filename: Filename to save download to + /// - concurrentDownloads: How many downloads can you have running at one time /// - threadPool: Thread pool used to save file /// - logger: logger /// - progress: Callback that returns the progress of the download. It is called after each part is downloaded with a value @@ -137,9 +169,10 @@ extension S3 { _ input: GetObjectRequest, partSize: Int = 5 * 1024 * 1024, filename: String, + concurrentDownloads: Int = 4, threadPool: NIOThreadPool = .singleton, logger: Logger = AWSClient.loggingDisabled, - progress: @escaping (Double) async throws -> Void = { _ in } + progress: @escaping @Sendable (Double) async throws -> Void = { _ in } ) async throws -> Int64 { let eventLoop = self.client.eventLoopGroup.any() let fileIO = NonBlockingFileIO(threadPool: threadPool) @@ -151,6 +184,7 @@ extension S3 { downloaded = try await self.multipartDownload( input, partSize: partSize, + concurrentDownloads: concurrentDownloads, logger: logger ) { byteBuffer, fileSize in let bufferSize = byteBuffer.readableBytes