diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h index d4d8c3d9f3f..f7f2e66fa07 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h @@ -8,8 +8,11 @@ #include #include -#include +#include +#include + #include +#include namespace Aws { @@ -26,8 +29,15 @@ namespace Aws StreamBufProtectedWriter() = delete; using WriterFunc = std::function; + using WriteCompleteCallback = std::function; + + static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) { + return WriteToBuffer(ioStream, writerFunc, [](uint64_t) -> void {}); + } - static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) + static uint64_t WriteToBuffer(Aws::IOStream& ioStream, + const WriterFunc& writerFunc, + const WriteCompleteCallback& writeCompleteCallback) { uint64_t totalRead = 0; @@ -53,6 +63,7 @@ namespace Aws { break; } + writeCompleteCallback(read); if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr())) { diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index cda0078abeb..7677e02052f 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -306,11 +306,6 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr& { readLimiter->ApplyAndPayForCost(read); } - auto& receivedHandler = request->GetDataReceivedEventHandler(); - if (receivedHandler) - { - receivedHandler(request.get(), response.get(), (long long)read); - } } if (!ContinueRequest(*request) || !IsRequestProcessingEnabled()) { @@ -319,7 +314,15 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr& } return connectionOpen && success && ContinueRequest(*request) && IsRequestProcessingEnabled(); }; - uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc); + uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), + writerFunc, + [&request, &response](uint64_t read) -> void { + auto& receivedHandler = request->GetDataReceivedEventHandler(); + if (receivedHandler) + { + receivedHandler(request.get(), response.get(), (long long)read); + } + }); if(!ContinueRequest(*request) || !IsRequestProcessingEnabled()) { diff --git a/tests/aws-cpp-sdk-core-tests/utils/stream/StreamBufProtectedWriterTest.cpp b/tests/aws-cpp-sdk-core-tests/utils/stream/StreamBufProtectedWriterTest.cpp new file mode 100644 index 00000000000..cc5579c7cdc --- /dev/null +++ b/tests/aws-cpp-sdk-core-tests/utils/stream/StreamBufProtectedWriterTest.cpp @@ -0,0 +1,68 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include +#include + +using namespace Aws::Utils::Stream; + +class StreamBufProtectedWriterTest : public Aws::Testing::AwsCppSdkGTestSuite {}; + +class MockedBuffer { +public: + MockedBuffer(std::initializer_list strings) : data_(strings) {} + + Aws::String read() { + if (data_.empty()) { + return {}; + } + Aws::String result = data_.front(); + data_.pop_front(); + return result; + } + +private: + Aws::Deque data_; +}; + +TEST_F(StreamBufProtectedWriterTest, ShouldBeAbleToAccessStreamAfterWriteFunction) { + Aws::StringStream bufferedStream; + Aws::StringStream output; + MockedBuffer srcBuffer({ "joker", + "skull", + "panther", + "mona"}); + Aws::String leftover{}; + StreamBufProtectedWriter::WriteToBuffer(bufferedStream, + [&srcBuffer, &leftover](char* dst, uint64_t dstSz, uint64_t& read) -> bool { + Aws::String data{}; + if (!leftover.empty()) { + data = leftover; + leftover.clear(); + } else { + data = srcBuffer.read(); + if (data.empty()) { + return false; + } + } + if (data.size() > dstSz) { + leftover = data.substr(static_cast(dstSz)); + read = dstSz; + memcpy(dst, data.c_str(), static_cast(dstSz)); + } else { + read = data.size(); + memcpy(dst, data.c_str(), static_cast(read)); + } + return true; + }, + [&bufferedStream, &output](uint64_t read) -> void { + AWS_UNREFERENCED_PARAM(read); + output << bufferedStream.rdbuf(); + }); + EXPECT_EQ(output.str(), "jokerskullpanthermona"); +}