Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/Array.h>
#include <streambuf>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>

#include <functional>
#include <streambuf>

namespace Aws
{
Expand All @@ -26,8 +29,15 @@ namespace Aws
StreamBufProtectedWriter() = delete;

using WriterFunc = std::function<bool(char* dst, uint64_t dstSz, uint64_t& read)>;
using WriteCompleteCallback = std::function<void(uint64_t read)>;

static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) {
return WriteToBuffer(ioStream, writerFunc, [](uint64_t) -> void {});
}

static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc)
static uint64_t WriteToBuffer(Aws::IOStream& ioStream,
const WriterFunc& writerFunc,
const WriteCompleteCallback& writeCompleteCallback)
{
uint64_t totalRead = 0;

Expand All @@ -53,6 +63,7 @@ namespace Aws
{
break;
}
writeCompleteCallback(read);

if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr()))
{
Expand Down
15 changes: 9 additions & 6 deletions src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,6 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
if (!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
Expand All @@ -319,7 +314,15 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&
}
return connectionOpen && success && ContinueRequest(*request) && IsRequestProcessingEnabled();
};
uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc);
uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(),
writerFunc,
[&request, &response](uint64_t read) -> void {
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
});

if(!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <aws/core/utils/stream/StreamBufProtectedWriter.h>
#include <aws/core/utils/memory/stl/AWSDeque.h>
#include <aws/testing/AwsCppSdkGTestSuite.h>
#include <sstream>
#include <deque>

using namespace Aws::Utils::Stream;

class StreamBufProtectedWriterTest : public Aws::Testing::AwsCppSdkGTestSuite {};

class MockedBuffer {
public:
MockedBuffer(std::initializer_list<Aws::String> strings) : data_(strings) {}

Aws::String read() {
if (data_.empty()) {
return {};
}
Aws::String result = data_.front();
data_.pop_front();
return result;
}

private:
Aws::Deque<Aws::String> data_;
};

TEST_F(StreamBufProtectedWriterTest, ShouldBeAbleToAccessStreamAfterWriteFunction) {
Aws::StringStream bufferedStream;
Aws::StringStream output;
MockedBuffer srcBuffer({ "joker",
"skull",
"panther",
"mona"});
Aws::String leftover{};
StreamBufProtectedWriter::WriteToBuffer(bufferedStream,
[&srcBuffer, &leftover](char* dst, uint64_t dstSz, uint64_t& read) -> bool {
Aws::String data{};
if (!leftover.empty()) {
data = leftover;
leftover.clear();
} else {
data = srcBuffer.read();
if (data.empty()) {
return false;
}
}
if (data.size() > dstSz) {
leftover = data.substr(static_cast<size_t>(dstSz));
read = dstSz;
memcpy(dst, data.c_str(), static_cast<size_t>(dstSz));
} else {
read = data.size();
memcpy(dst, data.c_str(), static_cast<size_t>(read));
}
return true;
},
[&bufferedStream, &output](uint64_t read) -> void {
AWS_UNREFERENCED_PARAM(read);
output << bufferedStream.rdbuf();
});
EXPECT_EQ(output.str(), "jokerskullpanthermona");
}
Loading