diff --git a/sdk/core/azure-core-http-netty/src/test/java/com/azure/core/http/netty/NettyFluxTests.java b/sdk/core/azure-core-http-netty/src/test/java/com/azure/core/http/netty/NettyFluxTests.java index 0c9b68f528b1c..76211f06eec3e 100644 --- a/sdk/core/azure-core-http-netty/src/test/java/com/azure/core/http/netty/NettyFluxTests.java +++ b/sdk/core/azure-core-http-netty/src/test/java/com/azure/core/http/netty/NettyFluxTests.java @@ -106,65 +106,60 @@ public void testCanReadEmptyFile() throws IOException { @Test public void testAsynchronyShortInput() throws IOException { -// File file = createFileIfNotExist("target/test3"); -// FileOutputStream stream = new FileOutputStream(file); -// stream.write("hello there".getBytes(StandardCharsets.UTF_8)); -// stream.close(); -// try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(file.toPath(), StandardOpenOption.READ)) { -// byte[] bytes = FluxUtil.byteBufStreamFromFile(channel) -// .map(bb -> { -// byte[] bt = bb.array(); -// ReferenceCountUtil.release(bb); -// return bt; -// }) -// .limitRequest(1) -// .subscribeOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) -// .publishOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) -// .collect(() -> new ByteArrayOutputStream(), -// (bos, b) -> { -// try { -// bos.write(b); -// } catch (IOException ioe) { -// throw Exceptions.propagate(ioe); -// } -// }) -// .block() -// .toByteArray(); -// assertEquals("hello there", new String(bytes, StandardCharsets.UTF_8)); -// } -// assertTrue(file.delete()); - Assert.fail("Need to implement this test again"); + File file = createFileIfNotExist("target/test3"); + FileOutputStream stream = new FileOutputStream(file); + stream.write("hello there".getBytes(StandardCharsets.UTF_8)); + stream.close(); + try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(file.toPath(), StandardOpenOption.READ)) { + byte[] bytes = FluxUtil.readFile(channel) + .map(bb -> { + byte[] bt = new byte[bb.remaining()]; + bb.get(bt); + return bt; + }) + .limitRequest(1) + .subscribeOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) + .publishOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) + .collect(() -> new ByteArrayOutputStream(), + (bos, b) -> { + try { + bos.write(b); + } catch (IOException ioe) { + throw Exceptions.propagate(ioe); + } + }) + .block() + .toByteArray(); + assertEquals("hello there", new String(bytes, StandardCharsets.UTF_8)); + } + assertTrue(file.delete()); } private static final int NUM_CHUNKS_IN_LONG_INPUT = 10_000_000; @Test public void testAsynchronyLongInput() throws IOException, NoSuchAlgorithmException { -// File file = createFileIfNotExist("target/test4"); -// byte[] array = "1234567690".getBytes(StandardCharsets.UTF_8); -// MessageDigest digest = MessageDigest.getInstance("MD5"); -// try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file))) { -// for (int i = 0; i < NUM_CHUNKS_IN_LONG_INPUT; i++) { -// out.write(array); -// digest.update(array); -// } -// } -// System.out.println("long input file size=" + file.length() / (1024 * 1024) + "MB"); -// byte[] expected = digest.digest(); -// digest.reset(); -// try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(file.toPath(), StandardOpenOption.READ)) { -// FluxUtil.byteBufStreamFromFile(channel) -// .subscribeOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) -// .publishOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) -// .toIterable().forEach(bb -> { -// digest.update(bb); -// ReferenceCountUtil.release(bb); -// }); -// -// assertArrayEquals(expected, digest.digest()); -// } -// assertTrue(file.delete()); - Assert.fail("Need to implement this test again"); + File file = createFileIfNotExist("target/test4"); + byte[] array = "1234567690".getBytes(StandardCharsets.UTF_8); + MessageDigest digest = MessageDigest.getInstance("MD5"); + try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file))) { + for (int i = 0; i < NUM_CHUNKS_IN_LONG_INPUT; i++) { + out.write(array); + digest.update(array); + } + } + System.out.println("long input file size=" + file.length() / (1024 * 1024) + "MB"); + byte[] expected = digest.digest(); + digest.reset(); + try (AsynchronousFileChannel channel = AsynchronousFileChannel.open(file.toPath(), StandardOpenOption.READ)) { + FluxUtil.readFile(channel) + .subscribeOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) + .publishOn(reactor.core.scheduler.Schedulers.newElastic("io", 30)) + .toIterable().forEach(digest::update); + + assertArrayEquals(expected, digest.digest()); + } + assertTrue(file.delete()); } @Test diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java index e7fb53df06383..407e41149e3f4 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/http/policy/HttpLoggingPolicy.java @@ -90,7 +90,7 @@ private Mono logRequest(final ClientLogger logger, final HttpRequest reque if (contentLength < MAX_BODY_LOG_SIZE && isHumanReadableContentType) { try { - Mono collectedBytes = FluxUtil.collectBytesInByteBufferStream(request.body(), true); + Mono collectedBytes = FluxUtil.collectBytesInByteBufferStream(request.body()); reqBodyLoggingMono = collectedBytes.flatMap(bytes -> { String bodyString = new String(bytes, StandardCharsets.UTF_8); bodyString = prettyPrintIfNeeded(logger, request.headers().value("Content-Type"), bodyString); diff --git a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/util/FluxUtil.java b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/util/FluxUtil.java index c983e15d69ca9..7b8562677ab2b 100644 --- a/sdk/core/azure-core/src/main/java/com/azure/core/implementation/util/FluxUtil.java +++ b/sdk/core/azure-core/src/main/java/com/azure/core/implementation/util/FluxUtil.java @@ -4,6 +4,12 @@ package com.azure.core.implementation.util; import com.azure.core.util.Context; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -18,14 +24,6 @@ import java.util.function.Function; import java.util.stream.Collectors; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoSink; -import reactor.core.publisher.Operators; - /** * Utility type exposing methods to deal with {@link Flux}. */ @@ -49,27 +47,9 @@ public static boolean isFluxByteBuffer(Type entityType) { /** * Collects ByteBuffer emitted by a Flux into a byte array. * @param stream A stream which emits ByteBuf instances. - * @param autoReleaseEnabled if ByteBuffer instances in stream gets automatically released as they consumed * @return A Mono which emits the concatenation of all the ByteBuf instances given by the source Flux. */ - public static Mono collectBytesInByteBufferStream(Flux stream, boolean autoReleaseEnabled) { -// if (autoReleaseEnabled) { -// // A stream is auto-release enabled means - the ByteBuf chunks in the stream get -// // released as consumer consumes each chunk. -// return Mono.using(Unpooled::compositeBuffer, -// cbb -> stream.collect(() -> cbb, -// (cbb1, buffer) -> cbb1.addComponent(true, Unpooled.wrappedBuffer(buffer).retain())), -// ReferenceCountUtil::release) -// .filter((CompositeByteBuf cbb) -> cbb.isReadable()) -// .map(FluxUtil::byteBufferToArray); -// } else { -// return stream.collect(Unpooled::compositeBuffer, -// (cbb1, buffer) -> cbb1.addComponent(true, Unpooled.wrappedBuffer(buffer))) -// .filter((CompositeByteBuf cbb) -> cbb.isReadable()) -// .map(FluxUtil::byteBufferToArray); -// } - - // TODO this is not a good implementation + public static Mono collectBytesInByteBufferStream(Flux stream) { return stream .collect(ByteArrayOutputStream::new, FluxUtil::accept) .map(ByteArrayOutputStream::toByteArray); @@ -77,9 +57,9 @@ public static Mono collectBytesInByteBufferStream(Flux strea private static void accept(ByteArrayOutputStream byteOutputStream, ByteBuffer byteBuffer) { try { - byteOutputStream.write(byteBuffer.array()); + byteOutputStream.write(byteBufferToArray(byteBuffer)); } catch (IOException e) { - e.printStackTrace(); + throw new RuntimeException(e); } } @@ -89,17 +69,14 @@ private static void accept(ByteArrayOutputStream byteOutputStream, ByteBuffer by * have optionally backing array. * * - * @param byteBuf the byte buffer + * @param byteBuffer the byte buffer * @return the byte array */ - public static byte[] byteBufferToArray(ByteBuffer byteBuf) { -// int length = byteBuf.readableBytes(); -// byte[] byteArray = new byte[length]; -// byteBuf.getBytes(byteBuf.readerIndex(), byteArray); -// return byteArray; - - // FIXME this is not good code! - return byteBuf.array(); + public static byte[] byteBufferToArray(ByteBuffer byteBuffer) { + int length = byteBuffer.remaining(); + byte[] byteArray = new byte[length]; + byteBuffer.get(byteArray); + return byteArray; } /** @@ -159,6 +136,314 @@ private static Context toAzureContext(reactor.util.context.Context context) { return Context.of(keyValues); } + /** + * Writes the bytes emitted by a Flux to an AsynchronousFileChannel. + * + * @param content the Flux content + * @param outFile the file channel + * @return a Mono which performs the write operation when subscribed + */ + public static Mono writeFile(Flux content, AsynchronousFileChannel outFile) { + return writeFile(content, outFile, 0); + } + + /** + * Writes the bytes emitted by a Flux to an AsynchronousFileChannel + * starting at the given position in the file. + * + * @param content the Flux content + * @param outFile the file channel + * @param position the position in the file to begin writing + * @return a Mono which performs the write operation when subscribed + */ + public static Mono writeFile(Flux content, AsynchronousFileChannel outFile, long position) { + return Mono.create(emitter -> content.subscribe(new Subscriber() { + // volatile ensures that writes to these fields by one thread will be immediately visible to other threads. + // An I/O pool thread will write to isWriting and read isCompleted, + // while another thread may read isWriting and write to isCompleted. + volatile boolean isWriting = false; + volatile boolean isCompleted = false; + volatile Subscription subscription; + volatile long pos = position; + + @Override + public void onSubscribe(Subscription s) { + subscription = s; + s.request(1); + } + + @Override + public void onNext(ByteBuffer bytes) { + isWriting = true; + outFile.write(bytes, pos, null, onWriteCompleted); + } + + + CompletionHandler onWriteCompleted = new CompletionHandler() { + @Override + public void completed(Integer bytesWritten, Object attachment) { + isWriting = false; + if (isCompleted) { + emitter.success(); + } + //noinspection NonAtomicOperationOnVolatileField + pos += bytesWritten; + subscription.request(1); + } + + @Override + public void failed(Throwable exc, Object attachment) { + subscription.cancel(); + emitter.error(exc); + } + }; + + @Override + public void onError(Throwable throwable) { + subscription.cancel(); + emitter.error(throwable); + } + + @Override + public void onComplete() { + isCompleted = true; + if (!isWriting) { + emitter.success(); + } + } + })); + } + + /** + * Creates a {@link Flux} from an {@link AsynchronousFileChannel} + * which reads part of a file into chunks of the given size. + * + * @param fileChannel The file channel. + * @param chunkSize the size of file chunks to read. + * @param offset The offset in the file to begin reading. + * @param length The number of bytes to read from the file. + * @return the Flux. + */ + public static Flux readFile(AsynchronousFileChannel fileChannel, int chunkSize, long offset, long length) { + return new FileReadFlux(fileChannel, chunkSize, offset, length); + } + + /** + * Creates a {@link Flux} from an {@link AsynchronousFileChannel} + * which reads part of a file. + * + * @param fileChannel The file channel. + * @param offset The offset in the file to begin reading. + * @param length The number of bytes to read from the file. + * @return the Flux. + */ + public static Flux readFile(AsynchronousFileChannel fileChannel, long offset, long length) { + return readFile(fileChannel, DEFAULT_CHUNK_SIZE, offset, length); + } + + /** + * Creates a {@link Flux} from an {@link AsynchronousFileChannel} + * which reads the entire file. + * + * @param fileChannel The file channel. + * @return The AsyncInputStream. + */ + public static Flux readFile(AsynchronousFileChannel fileChannel) { + try { + long size = fileChannel.size(); + return readFile(fileChannel, DEFAULT_CHUNK_SIZE, 0, size); + } catch (IOException e) { + return Flux.error(e); + } + } + + private static final int DEFAULT_CHUNK_SIZE = 1024 * 64; + + private static final class FileReadFlux extends Flux { + private final AsynchronousFileChannel fileChannel; + private final int chunkSize; + private final long offset; + private final long length; + + FileReadFlux(AsynchronousFileChannel fileChannel, int chunkSize, long offset, long length) { + this.fileChannel = fileChannel; + this.chunkSize = chunkSize; + this.offset = offset; + this.length = length; + } + + @Override + public void subscribe(CoreSubscriber actual) { + FileReadSubscription subscription = new FileReadSubscription(actual, fileChannel, chunkSize, offset, length); + actual.onSubscribe(subscription); + } + + static final class FileReadSubscription implements Subscription, CompletionHandler { + private static final int NOT_SET = -1; + private static final long serialVersionUID = -6831808726875304256L; + // + private final Subscriber subscriber; + private volatile long position; + // + private final AsynchronousFileChannel fileChannel; + private final int chunkSize; + private final long offset; + private final long length; + // + private volatile boolean done; + private Throwable error; + private volatile ByteBuffer next; + private volatile boolean cancelled; + // + volatile int wip; + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(FileReadSubscription.class, "wip"); + volatile long requested; + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(FileReadSubscription.class, "requested"); + // + + FileReadSubscription(Subscriber subscriber, AsynchronousFileChannel fileChannel, int chunkSize, long offset, long length) { + this.subscriber = subscriber; + // + this.fileChannel = fileChannel; + this.chunkSize = chunkSize; + this.offset = offset; + this.length = length; + // + this.position = NOT_SET; + } + + //region Subscription implementation + + @Override + public void request(long n) { + if (Operators.validate(n)) { + Operators.addCap(REQUESTED, this, n); + drain(); + } + } + + @Override + public void cancel() { + this.cancelled = true; + } + + //endregion + + //region CompletionHandler implementation + + @Override + public void completed(Integer bytesRead, ByteBuffer buffer) { + if (!cancelled) { + if (bytesRead == -1) { + done = true; + } else { + // use local variable to perform fewer volatile reads + long pos = position; + int bytesWanted = (int) Math.min(bytesRead, maxRequired(pos)); + long position2 = pos + bytesWanted; + //noinspection NonAtomicOperationOnVolatileField + position = position2; + buffer.position(bytesWanted); + buffer.flip(); + next = buffer; + if (position2 >= offset + length) { + done = true; + } + } + drain(); + } + } + + @Override + public void failed(Throwable exc, ByteBuffer attachment) { + if (!cancelled) { + // must set error before setting done to true + // so that is visible in drain loop + error = exc; + done = true; + drain(); + } + } + + //endregion + + private void drain() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + // on first drain (first request) we initiate the first read + if (position == NOT_SET) { + position = offset; + doRead(); + } + int missed = 1; + for (;;) { + if (cancelled) { + return; + } + if (REQUESTED.get(this) > 0) { + boolean emitted = false; + // read d before next to avoid race + boolean d = done; + ByteBuffer bb = next; + if (bb != null) { + next = null; + subscriber.onNext(bb); + emitted = true; + } else { + emitted = false; + } + if (d) { + if (error != null) { + subscriber.onError(error); + // exit without reducing wip so that further drains will be NOOP + return; + } else { + subscriber.onComplete(); + // exit without reducing wip so that further drains will be NOOP + return; + } + } + if (emitted) { + // do this after checking d to avoid calling read + // when done + Operators.produced(REQUESTED, this, 1); + // + doRead(); + } + } + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return; + } + } + } + + private void doRead() { + // use local variable to limit volatile reads + long pos = position; + ByteBuffer innerBuf = ByteBuffer.allocate(Math.min(chunkSize, maxRequired(pos))); + fileChannel.read(innerBuf, pos, innerBuf, this); + } + + private int maxRequired(long pos) { + long maxRequired = offset + length - pos; + if (maxRequired <= 0) { + return 0; + } else { + int m = (int) (maxRequired); + // support really large files by checking for overflow + if (m < 0) { + return Integer.MAX_VALUE; + } else { + return m; + } + } + } + } + } // Private Ctr diff --git a/sdk/core/azure-core/src/test/java/com/azure/core/http/MockHttpClient.java b/sdk/core/azure-core/src/test/java/com/azure/core/http/MockHttpClient.java index ec57d806252be..ed5f7e7f0986d 100644 --- a/sdk/core/azure-core/src/test/java/com/azure/core/http/MockHttpClient.java +++ b/sdk/core/azure-core/src/test/java/com/azure/core/http/MockHttpClient.java @@ -12,6 +12,7 @@ import com.azure.core.implementation.util.FluxUtil; import reactor.core.publisher.Mono; +import java.io.ByteArrayOutputStream; import java.net.URL; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import java.util.stream.Collectors; /** * This HttpClient attempts to mimic the behavior of http://httpbin.org without ever making a network call. @@ -170,13 +172,8 @@ public Mono send(HttpRequest request) { response = new MockHttpResponse(request, statusCode); } } else if ("echo.org".equalsIgnoreCase(requestHost)) { - return request.body() - .collectList() - .map(list -> { -// byte[] bytes = Unpooled.wrappedBuffer(list.toArray(new ByteBuffer[0])).array(); -// return new MockHttpResponse(request, 200, new HttpHeaders(request.headers()), bytes); - throw new IllegalStateException("Code needs to be reimplemented"); - }); + return FluxUtil.collectBytesInByteBufferStream(request.body()) + .map(bytes -> new MockHttpResponse(request, 200, new HttpHeaders(request.headers()), bytes)); } } catch (Exception ex) { return Mono.error(ex); @@ -216,7 +213,7 @@ private static String createHttpBinResponseDataForRequest(HttpRequest request) { private static String bodyToString(HttpRequest request) { String body = ""; if (request.body() != null) { - Mono asyncString = FluxUtil.collectBytesInByteBufferStream(request.body(), true) + Mono asyncString = FluxUtil.collectBytesInByteBufferStream(request.body()) .map(bytes -> new String(bytes, StandardCharsets.UTF_8)); body = asyncString.block(); } diff --git a/sdk/core/azure-core/src/test/java/com/azure/core/implementation/RestProxyXMLTests.java b/sdk/core/azure-core/src/test/java/com/azure/core/implementation/RestProxyXMLTests.java index 50a6d9ba99bc8..9f642f8617fb1 100644 --- a/sdk/core/azure-core/src/test/java/com/azure/core/implementation/RestProxyXMLTests.java +++ b/sdk/core/azure-core/src/test/java/com/azure/core/implementation/RestProxyXMLTests.java @@ -115,7 +115,7 @@ static class MockXMLReceiverClient implements HttpClient { @Override public Mono send(HttpRequest request) { if (request.url().toString().endsWith("SetContainerACLs")) { - return FluxUtil.collectBytesInByteBufferStream(request.body(), false) + return FluxUtil.collectBytesInByteBufferStream(request.body()) .map(bytes -> { receivedBytes = bytes; return new MockHttpResponse(request, 200);