Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/20250909150027.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:sparkles: `safeio` Add support for cancelling readers that make blocking kernel reads during copying
27 changes: 27 additions & 0 deletions utils/safeio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ func Cat(ctx context.Context, dst io.Writer, src ...io.Reader) (copied int64, er
return CopyDataWithContext(ctx, NewContextualMultipleReader(ctx, src...), dst)
}

// SafeCopyDataWithContext copies from src to dst similarly to io.Copy but with context control to stop when asked.
// Unlike CopyWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
func SafeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer) (copied int64, err error) {
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.Copy(dst, src) })
}

// SafeCopyNWithContext copies n bytes from src to dst similarly to io.CopyN but with context control to stop when asked.
// Unlike CopyNWithContext it requires a ReadCloser, this allows it to stop even if the system is doing a kernel read.
func SafeCopyNWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, n int64) (copied int64, err error) {
return safeCopyDataWithContext(ctx, src, dst, func(dst io.Writer, src io.ReadCloser) (int64, error) { return io.CopyN(dst, src, n) })
}

func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copyFunc func(io.Writer, io.Reader) (int64, error)) (copied int64, err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
Expand All @@ -37,8 +49,23 @@ func copyDataWithContext(ctx context.Context, src io.Reader, dst io.Writer, copy
return
}

func safeCopyDataWithContext(ctx context.Context, src io.ReadCloser, dst io.Writer, copyFunc func(io.Writer, io.ReadCloser) (int64, error)) (copied int64, err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
return
}
copied, err = reallySafeCopy(ContextualWriter(ctx, dst), NewContextualReadCloser(ctx, src), copyFunc)
return
}

func safeCopy(w io.Writer, r io.Reader, iocopyFunc func(io.Writer, io.Reader) (int64, error)) (int64, error) {
copied, err := iocopyFunc(w, r)
err = ConvertIOError(err)
return copied, err
}

func reallySafeCopy(w io.Writer, r io.ReadCloser, iocopyFunc func(io.Writer, io.ReadCloser) (int64, error)) (int64, error) {
copied, err := iocopyFunc(w, r)
err = ConvertIOError(err)
return copied, err
}
133 changes: 133 additions & 0 deletions utils/safeio/copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package safeio
import (
"bytes"
"context"
"io"
"os"
"testing"
"time"

"github.com/go-faker/faker/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
Expand Down Expand Up @@ -81,6 +85,135 @@ func TestCopyNWithContext(t *testing.T) {
assert.Equal(t, safecast.ToInt64(len(text)-1), n2)
}

func TestSafeCopyDataWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var buf1, buf2 bytes.Buffer
text := faker.Sentence()
n, err := WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)
rc := io.NopCloser(bytes.NewReader(buf1.Bytes())) // make it an io.ReadCloser
n2, err := SafeCopyDataWithContext(context.Background(), rc, &buf2)
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, safecast.ToInt64(len(text)), n2)
assert.Equal(t, text, buf2.String())

ctx, cancel := context.WithCancel(context.Background())
buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)

cancel()
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err = SafeCopyDataWithContext(ctx, rc, &buf2)
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.Zero(t, n2)
assert.Empty(t, buf2.String())

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()
ctx2, unblock := context.WithCancel(context.Background())
done := make(chan struct{})

go func() {
_, errCopy := SafeCopyDataWithContext(ctx2, r, io.Discard)
_ = r.Close()
_ = errCopy
close(done)
}()

time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
unblock()

select {
case <-done:
// Expected case: unblocked
case <-time.After(2 * time.Second):
assert.FailNow(t, "context cancel should have unblocked copy")
}
}

func TestSafeCopyNWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var buf1, buf2 bytes.Buffer
text := faker.Sentence()
n, err := WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)
rc := io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err := SafeCopyNWithContext(context.Background(), rc, &buf2, safecast.ToInt64(len(text)))
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, safecast.ToInt64(len(text)), n2)
assert.Equal(t, text, buf2.String())

ctx, cancel := context.WithCancel(context.Background())

buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
assert.Equal(t, len(text), n)

cancel()
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))
n2, err = SafeCopyNWithContext(ctx, rc, &buf2, safecast.ToInt64(len(text)))
require.Error(t, err)
errortest.AssertError(t, err, commonerrors.ErrCancelled)
assert.Zero(t, n2)
assert.Empty(t, buf2.String())

buf1.Reset()
buf2.Reset()
n, err = WriteString(context.Background(), &buf1, text)
require.NoError(t, err)
require.NotZero(t, n)
rc = io.NopCloser(bytes.NewReader(buf1.Bytes()))

wantN := safecast.ToInt64(len(text) - 1)
n2, err = SafeCopyNWithContext(context.Background(), rc, &buf2, wantN)
require.NoError(t, err)
require.NotZero(t, n2)
assert.Equal(t, wantN, n2)
assert.Equal(t, text[:len(text)-1], buf2.String())

r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()
ctx2, unblock := context.WithCancel(context.Background())
done := make(chan struct{})
var (
copied int64
copyErr error
)

go func() {
copied, copyErr = SafeCopyNWithContext(ctx2, r, io.Discard, 1024) // nothing to read means it blocks
_ = r.Close()
close(done)
}()

time.Sleep(50 * time.Millisecond) // let it enter read(2) https://man7.org/linux/man-pages/man2/read.2.html
unblock()

select {
case <-done:
errortest.AssertError(t, copyErr, commonerrors.ErrCancelled)
assert.Zero(t, copied)
case <-time.After(2 * time.Second):
assert.FailNow(t, "context cancel should have unblocked copy")
}
}

func TestCat(t *testing.T) {
var buf1, buf2, buf3 bytes.Buffer
text1 := faker.Sentence()
Expand Down
4 changes: 4 additions & 0 deletions utils/safeio/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package safeio

import (
"io"
"os"

"github.com/ARM-software/golang-utils/utils/commonerrors"
)
Expand All @@ -16,6 +17,9 @@ func ConvertIOError(err error) (newErr error) {
case commonerrors.Any(newErr, commonerrors.ErrEOF):
case commonerrors.Any(newErr, io.EOF, io.ErrUnexpectedEOF):
newErr = commonerrors.WrapError(commonerrors.ErrEOF, newErr, "")
case commonerrors.Any(newErr, os.ErrClosed):
// cancelling a reader on a copy will cause it to close the file and return os.ErrClosed so map it to cancelled for this package
newErr = commonerrors.WrapError(commonerrors.ErrCancelled, newErr, "")
}
return
}
37 changes: 37 additions & 0 deletions utils/safeio/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"

"github.com/dolmen-go/contextio"
"go.uber.org/atomic"

"github.com/ARM-software/golang-utils/utils/commonerrors"
"github.com/ARM-software/golang-utils/utils/parallelisation"
Expand Down Expand Up @@ -76,6 +77,42 @@ func NewContextualReader(ctx context.Context, reader io.Reader) io.Reader {
return contextio.NewReader(ctx, reader)
}

type safeReadCloser struct {
reader io.Reader
close parallelisation.CloseFunc
closed *atomic.Bool
}

func (r safeReadCloser) Read(p []byte) (int, error) {
return r.reader.Read(p)
}

func (r safeReadCloser) Close() error {
if r.closed.Swap(true) {
return nil
}

return r.close()
}

// NewContextualReadCloser returns a readcloser which is context aware.
// Context state is checked during the read and close is called if the context is cancelled
// This allows for readers that block on syscalls to be stopped via a context
func NewContextualReadCloser(ctx context.Context, reader io.ReadCloser) io.ReadCloser {
stop := context.AfterFunc(ctx, func() { _ = reader.Close() })

r := safeReadCloser{
reader: contextio.NewReader(ctx, reader),
close: func() error {
_ = stop()
return reader.Close()
},
closed: atomic.NewBool(false),
}

return r
}

func NewContextualMultipleReader(ctx context.Context, reader ...io.Reader) io.Reader {
readers := make([]io.Reader, len(reader))
for i := range reader {
Expand Down
65 changes: 65 additions & 0 deletions utils/safeio/read_closer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package safeio

import (
"context"
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewContextualReadCloser(t *testing.T) {
t.Run("Normal contextual reader blocks even after cancel", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = r.Close(); _ = w.Close() }()

ctx, cancel := context.WithCancel(context.Background())
reader := NewContextualReader(ctx, r)

done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, reader) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
close(done)
}()

// Allow io.Copy to enter kernel read then try to cancel
time.Sleep(50 * time.Millisecond)
cancel()

select {
case <-done:
assert.FailNow(t, "cancelling context shouldn't unblock a blocking Read in io.Copy")
case <-time.After(200 * time.Millisecond):
// Expected case: still blocked
}
})

t.Run("Contextual read closer does not block even on long running copies", func(t *testing.T) {
r, w, err := os.Pipe()
require.NoError(t, err)
defer func() { _ = w.Close() }()

ctx, cancel := context.WithCancel(context.Background())
rc := NewContextualReadCloser(ctx, r)

done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, rc) // will block in read(2) https://man7.org/linux/man-pages/man2/read.2.html
close(done)
}()

time.Sleep(50 * time.Millisecond)
cancel()

select {
case <-done:
// Expected case: successfully unblocked
case <-time.After(2 * time.Second):
assert.FailNow(t, "copy should have been unblocked by context cancel")
}
})
}
Loading