diff --git a/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java b/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java index 0cb99188d37d..68b8dbcbe2be 100644 --- a/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java +++ b/spring-web/src/main/java/org/springframework/http/client/OutputStreamPublisher.java @@ -21,13 +21,13 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.Objects; -import java.util.concurrent.Exchanger; import java.util.concurrent.Executor; import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.concurrent.locks.LockSupport; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -40,11 +40,6 @@ */ final class OutputStreamPublisher implements Flow.Publisher { - private static final ByteBuffer CLOSED = ByteBuffer.allocate(0); - - private static final ByteBuffer CANCELED = ByteBuffer.allocate(0); - - private final OutputStreamHandler outputStreamHandler; private final Executor executor; @@ -94,7 +89,9 @@ public static Flow.Publisher create(OutputStreamHandler outputStream public void subscribe(Flow.Subscriber subscriber) { Objects.requireNonNull(subscriber, "Subscriber must not be null"); - subscriber.onSubscribe(new OutputStreamSubscription(subscriber, this.outputStreamHandler, this.executor)); + OutputStreamSubscription subscription = new OutputStreamSubscription(subscriber, this.outputStreamHandler); + subscriber.onSubscribe(subscription); + this.executor.execute(subscription::invokeHandler); } @@ -130,199 +127,284 @@ public interface OutputStreamHandler { } - private static final class OutputStreamSubscription implements Flow.Subscription { + private static final class OutputStreamSubscription extends OutputStream implements Flow.Subscription { - private final Flow.Subscriber subscriber; + static final Object READY = new Object(); + + private final Flow.Subscriber actual; private final OutputStreamHandler outputStreamHandler; - private final Executor executor; - private final AtomicBoolean handlerInvoked = new AtomicBoolean(); + @Nullable + private volatile Throwable error; - private final AtomicLong demand = new AtomicLong(); + private volatile long requested; + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(OutputStreamSubscription.class, "requested"); - private final Exchanger exchanger = new Exchanger<>(); + @Nullable + private volatile Object parkedThread; + static final AtomicReferenceFieldUpdater PARKED_THREAD = + AtomicReferenceFieldUpdater.newUpdater(OutputStreamSubscription.class, Object.class, "parkedThread"); - private volatile boolean canceled = false; + long produced; - public OutputStreamSubscription(Flow.Subscriber subscriber, - OutputStreamHandler outputStreamHandler, - Executor executor) { - this.subscriber = subscriber; + public OutputStreamSubscription(Flow.Subscriber actual, + OutputStreamHandler outputStreamHandler) { + this.actual = actual; this.outputStreamHandler = outputStreamHandler; - this.executor = executor; } - @Override - public void request(long n) { - Assert.isTrue(n > 0, "request should be a positive number"); - - long prev = this.demand.getAndAccumulate(n, (cur, giv) -> { - long sum = cur + giv; - return sum < 0 ? Long.MAX_VALUE : sum; - }); - if (this.handlerInvoked.compareAndSet(false, true)) { - this.executor.execute(this::invokeHandler); - } - if (prev == 0) { - exchangeBuffer(); - } - } + public void write(int b) throws IOException { + checkDemandAndAwaitIfNeeded(); - private void invokeHandler() { - // use BufferedOutputStream, so that written bytes are buffered - // before publishing as byte buffer - try (OutputStream outputStream = new BufferedOutputStream( - new ExchangerOutputStream(this.exchanger, this.subscriber))) { + ByteBuffer byteBuffer = ByteBuffer.allocate(1); + byteBuffer.put((byte) b); + byteBuffer.flip(); - this.outputStreamHandler.handle(outputStream); - } - catch (IOException ex) { - if (!this.canceled) { - this.subscriber.onError(ex); - } - } + this.actual.onNext(byteBuffer); + + this.produced++; } @Override - public void cancel() { - this.canceled = true; + public void write(byte[] b) throws IOException { + write(b, 0, b.length); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + checkDemandAndAwaitIfNeeded(); + + ByteBuffer byteBuffer = ByteBuffer.allocate(len); + byteBuffer.put(b, off, len); + byteBuffer.flip(); + + this.actual.onNext(byteBuffer); + + this.produced++; } - private void exchangeBuffer() { - long demand = this.demand.get(); - try { - while (demand > 0 && !this.canceled) { - ByteBuffer byteBuffer = this.exchanger.exchange(null); - if (byteBuffer != CLOSED) { - demand = publishBuffer(byteBuffer); + private void checkDemandAndAwaitIfNeeded() throws IOException { + long r = this.requested; + + if (isTerminated(r) || isCancelled(r)) { + throw new IOException("Subscription has been terminated"); + } + + long p = this.produced; + if (p == r) { + if (p > 0) { + r = tryProduce(p); + this.produced = 0; + } + + for (;;) { + if (isTerminated(r) || isCancelled(r)) { + throw new IOException("Subscription has been terminated"); } - else { - this.subscriber.onComplete(); - demand = 0; + + if (r != 0) { + return; } + + await(); + + r = this.requested; } - if (this.canceled) { - this.exchanger.exchange(CANCELED); - } - } - catch (InterruptedException ex) { - this.subscriber.onError(ex); } } - private long publishBuffer(ByteBuffer byteBuffer) { - this.subscriber.onNext(byteBuffer); - return this.demand.decrementAndGet(); + @Override + public void close() { } - } + private void invokeHandler() { + // assume sync write within try-with-resource block - private static final class ExchangerOutputStream extends OutputStream { + // use BufferedOutputStream, so that written bytes are buffered + // before publishing as byte buffer + try (OutputStream outputStream = new BufferedOutputStream(this)) { + this.outputStreamHandler.handle(outputStream); + } + catch (IOException ex) { + long previousState = tryTerminate(); + if (isCancelled(previousState)) { + return; + } - private final AtomicReference state = new AtomicReference<>(State.OPEN); + if (isTerminated(previousState)) { + // failure due to illegal requestN + this.actual.onError(this.error); + return; + } - private final Exchanger exchanger; + this.actual.onError(ex); + return; + } - private final Flow.Subscriber subscriber; + long previousState = tryTerminate(); + if (isCancelled(previousState)) { + return; + } + if (isTerminated(previousState)) { + // failure due to illegal requestN + this.actual.onError(this.error); + return; + } - public ExchangerOutputStream(Exchanger exchanger, Flow.Subscriber subscriber) { - this.exchanger = exchanger; - this.subscriber = subscriber; + this.actual.onComplete(); } @Override - public void write(int b) throws IOException { - this.state.get().write((byte) b, this); - } + public void request(long n) { + if (n <= 0) { + this.error = new IllegalArgumentException("request should be a positive number"); + long previousState = tryTerminate(); - @Override - public void write(byte[] b) throws IOException { - this.state.get().write(b, 0, b.length, this); + if (isTerminated(previousState) || isCancelled(previousState)) { + return; + } + + if (previousState > 0) { + // error should eventually be observed and propagated + return; + } + + // resume parked thread so it can observe error and propagate it + resume(); + return; + } + + if (addCap(n) == 0) { + // resume parked thread so it can continue the work + resume(); + } } @Override - public void write(byte[] b, int off, int len) throws IOException { - this.state.get().write(b, off, len, this); + public void cancel() { + long previousState = tryCancel(); + if (isCancelled(previousState) || previousState > 0) { + return; + } + + // resume parked thread so it can be unblocked and close all the resources + resume(); } - private void exchange(ByteBuffer byteBuffer) { - try { - ByteBuffer result = this.exchanger.exchange(byteBuffer); - if (result == CANCELED) { - this.state.compareAndSet(State.OPEN, State.CANCELED); + private void await() { + Thread toUnpark = Thread.currentThread(); + + for (;;) { + Object current = parkedThread; + if (current == READY) { + break; + } + + if (current != null && current != toUnpark) { + throw new IllegalStateException("Only one (Virtual)Thread can await!"); + } + + if (PARKED_THREAD.compareAndSet(this, null, toUnpark)) { + LockSupport.park(); + // we don't just break here because park() can wake up spuriously + // if we got a proper resume, get() == READY and the loop will quit above } } - catch (InterruptedException ex) { - this.subscriber.onError(ex); + // clear the resume indicator so that the next await call will park without a resume() + PARKED_THREAD.lazySet(this, null); + } + + private void resume() { + if (parkedThread != READY) { + Object old = PARKED_THREAD.getAndSet(this, READY); + if (old != READY) { + LockSupport.unpark((Thread)old); + } } } - @Override - public void close() { - if (this.state.compareAndSet(State.OPEN, State.CLOSED)) { - try { - this.exchanger.exchange(CLOSED); + private long tryCancel() { + for (;;) { + long r = requested; + + if (isCancelled(r)) { + return r; } - catch (InterruptedException ex) { - this.subscriber.onError(ex); + + if (REQUESTED.compareAndSet(this, r, Long.MIN_VALUE)) { + return r; } } } - private enum State { + private long tryTerminate() { + for (;;) { + long r = requested; - OPEN { - @Override - public void write(byte b, ExchangerOutputStream wrapper) throws IOException { - ByteBuffer byteBuffer = ByteBuffer.allocate(1); - byteBuffer.put(b); - byteBuffer.flip(); - wrapper.exchange(byteBuffer); + if (isCancelled(r) || isTerminated(r)) { + return r; } - @Override - public void write(byte[] b, int off, int len, ExchangerOutputStream wrapper) throws IOException { - ByteBuffer byteBuffer = ByteBuffer.allocate(len); - byteBuffer.put(b, off, len); - byteBuffer.flip(); - wrapper.exchange(byteBuffer); - } - }, CLOSED { - @Override - public void write(byte b, ExchangerOutputStream wrapper) throws IOException { - throw new IOException("Stream closed"); + if (REQUESTED.compareAndSet(this, r, Long.MIN_VALUE | Long.MAX_VALUE)) { + return r; } + } + } - @Override - public void write(byte[] bytes, int off, int len, ExchangerOutputStream wrapper) throws IOException { - throw new IOException("Stream closed"); + private long tryProduce(long n) { + for (; ; ) { + long current = this.requested; + if (isTerminated(current) || isCancelled(current)) { + return current; + } + if (current == Long.MAX_VALUE) { + return Long.MAX_VALUE; } - }, CANCELED { - @Override - public void write(byte b, ExchangerOutputStream wrapper) throws IOException { - throw new IOException("Subscription has been cancelled"); + long update = current - n; + if (update < 0L) { + update = 0L; } + if (REQUESTED.compareAndSet(this, current, update)) { + return update; + } + } + } - @Override - public void write(byte[] bytes, int off, int len, ExchangerOutputStream wrapper) throws IOException { - throw new IOException("Subscription has been cancelled"); + private long addCap(long n) { + for (; ; ) { + long r = this.requested; + if (isTerminated(r) || isCancelled(r) || r == Long.MAX_VALUE) { + return r; } - }; + long u = addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + return r; + } + } + } - public abstract void write(byte b, ExchangerOutputStream wrapper) throws IOException; + static boolean isTerminated(long state) { + return state == (Long.MIN_VALUE | Long.MAX_VALUE); + } - public abstract void write(byte[] bytes, int off, int len, ExchangerOutputStream wrapper) throws IOException; + static boolean isCancelled(long state) { + return state == Long.MIN_VALUE; + } + static long addCap(long a, long b) { + long res = a + b; + if (res < 0L) { + return Long.MAX_VALUE; + } + return res; } } - - - } diff --git a/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java b/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java index fb928dd5daf4..57d4b993275a 100644 --- a/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/OutputStreamPublisherTests.java @@ -83,15 +83,14 @@ void cancel() throws InterruptedException { Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { try (Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { - writer.write("foo"); - writer.flush(); - writer.write("bar"); - writer.flush(); - assertThatIOException().isThrownBy(() -> { - writer.write("baz"); + assertThatIOException() + .isThrownBy(() -> { + writer.write("foo"); + writer.flush(); + writer.write("bar"); writer.flush(); }) - .withMessage("Subscription has been cancelled"); + .withMessage("Subscription has been terminated"); latch.countDown(); } }, this.executor); @@ -126,6 +125,54 @@ void closed() throws InterruptedException { latch.await(); } + @Test + void negativeRequestN() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + + Flow.Publisher flowPublisher = OutputStreamPublisher.create(outputStream -> { + try(Writer writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { + writer.write("foo"); + writer.flush(); + writer.write("foo"); + writer.flush(); + } + finally { + latch.countDown(); + } + }, this.executor); + Flow.Subscription[] subscriptions = new Flow.Subscription[1]; + Flux flux = toString((a) -> flowPublisher.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptions[0] = subscription; + a.onSubscribe(subscription); + } + + @Override + public void onNext(ByteBuffer item) { + a.onNext(item); + } + + @Override + public void onError(Throwable throwable) { + a.onError(throwable); + } + + @Override + public void onComplete() { + a.onComplete(); + } + })); + + StepVerifier.create(flux, 1) + .assertNext(s -> assertThat(s).isEqualTo("foo")) + .then(() -> subscriptions[0].request(-1)) + .expectErrorMessage("request should be a positive number") + .verify(); + + latch.await(); + } + private static Flux toString(Flow.Publisher flowPublisher) { return Flux.from(FlowAdapters.toPublisher(flowPublisher)) .map(bb -> StandardCharsets.UTF_8.decode(bb).toString());