Commit d2dc6c09 authored by Kirill Smelkov's avatar Kirill Smelkov

xio: Teach Pipe to support cancellation

Rework Pipe to create (xio.Reader, xio.Writer) instead of (io.Reader,
io.Writer) and teach xio.Pipe{Reader,Writer} to accept ctx argument and
handle cancellation.

I need this to support sysread(/head/watch) cancellation in WCFS
filesystem [1,2,3]. See also [4].

[1] wendelin.core@b17aeb8c
[2] wendelin.core@f05271b1
[3] wendelin.core@5ba816da
[4] https://github.com/golang/go/issues/20280
parent 9db4dfac
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE-go file. // license that can be found in the LICENSE-go file.
// Pipe adapter to connect code expecting an io.Reader // Pipe adapter to connect code expecting an xio.Reader
// with code expecting an io.Writer. // with code expecting an xio.Writer.
package xio package xio
import ( import (
"context"
"io" "io"
"sync" "sync"
) )
...@@ -44,10 +45,12 @@ type pipe struct { ...@@ -44,10 +45,12 @@ type pipe struct {
werr onceError werr onceError
} }
func (p *pipe) Read(b []byte) (n int, err error) { func (p *pipe) Read(ctx context.Context, b []byte) (n int, err error) {
select { select {
case <-p.done: case <-p.done:
return 0, p.readCloseError() return 0, p.readCloseError()
case <-ctx.Done():
return 0, ctx.Err()
default: default:
} }
...@@ -58,6 +61,8 @@ func (p *pipe) Read(b []byte) (n int, err error) { ...@@ -58,6 +61,8 @@ func (p *pipe) Read(b []byte) (n int, err error) {
return nr, nil return nr, nil
case <-p.done: case <-p.done:
return 0, p.readCloseError() return 0, p.readCloseError()
case <-ctx.Done():
return 0, ctx.Err()
} }
} }
...@@ -78,10 +83,12 @@ func (p *pipe) CloseRead(err error) error { ...@@ -78,10 +83,12 @@ func (p *pipe) CloseRead(err error) error {
return nil return nil
} }
func (p *pipe) Write(b []byte) (n int, err error) { func (p *pipe) Write(ctx context.Context, b []byte) (n int, err error) {
select { select {
case <-p.done: case <-p.done:
return 0, p.writeCloseError() return 0, p.writeCloseError()
case <-ctx.Done():
return 0, ctx.Err()
default: default:
p.wrMu.Lock() p.wrMu.Lock()
defer p.wrMu.Unlock() defer p.wrMu.Unlock()
...@@ -95,6 +102,8 @@ func (p *pipe) Write(b []byte) (n int, err error) { ...@@ -95,6 +102,8 @@ func (p *pipe) Write(b []byte) (n int, err error) {
n += nw n += nw
case <-p.done: case <-p.done:
return n, p.writeCloseError() return n, p.writeCloseError()
case <-ctx.Done():
return n, ctx.Err()
} }
} }
return n, nil return n, nil
...@@ -118,17 +127,19 @@ func (p *pipe) CloseWrite(err error) error { ...@@ -118,17 +127,19 @@ func (p *pipe) CloseWrite(err error) error {
} }
// A PipeReader is the read half of a pipe. // A PipeReader is the read half of a pipe.
//
// It is similar to io.PipeReader, but additionally provides cancellation support for Read.
type PipeReader struct { type PipeReader struct {
p *pipe p *pipe
} }
// Read implements the standard Read interface: // Read implements xio.Reader interface:
// it reads data from the pipe, blocking until a writer // it reads data from the pipe, blocking until a writer
// arrives or the write end is closed. // arrives or the write end is closed.
// If the write end is closed with an error, that error is // If the write end is closed with an error, that error is
// returned as err; otherwise err is EOF. // returned as err; otherwise err is EOF.
func (r *PipeReader) Read(data []byte) (n int, err error) { func (r *PipeReader) Read(ctx context.Context, data []byte) (n int, err error) {
return r.p.Read(data) return r.p.Read(ctx, data)
} }
// Close closes the reader; subsequent writes to the // Close closes the reader; subsequent writes to the
...@@ -147,17 +158,19 @@ func (r *PipeReader) CloseWithError(err error) error { ...@@ -147,17 +158,19 @@ func (r *PipeReader) CloseWithError(err error) error {
} }
// A PipeWriter is the write half of a pipe. // A PipeWriter is the write half of a pipe.
//
// It is similar to io.PipeWriter, but additionally provides cancellation support for Write.
type PipeWriter struct { type PipeWriter struct {
p *pipe p *pipe
} }
// Write implements the standard Write interface: // Write implements xio.Writer interface:
// it writes data to the pipe, blocking until one or more readers // it writes data to the pipe, blocking until one or more readers
// have consumed all the data or the read end is closed. // have consumed all the data or the read end is closed.
// If the read end is closed with an error, that err is // If the read end is closed with an error, that err is
// returned as err; otherwise err is io.ErrClosedPipe. // returned as err; otherwise err is io.ErrClosedPipe.
func (w *PipeWriter) Write(data []byte) (n int, err error) { func (w *PipeWriter) Write(ctx context.Context, data []byte) (n int, err error) {
return w.p.Write(data) return w.p.Write(ctx, data)
} }
// Close closes the writer; subsequent reads from the // Close closes the writer; subsequent reads from the
...@@ -177,8 +190,8 @@ func (w *PipeWriter) CloseWithError(err error) error { ...@@ -177,8 +190,8 @@ func (w *PipeWriter) CloseWithError(err error) error {
} }
// Pipe creates a synchronous in-memory pipe. // Pipe creates a synchronous in-memory pipe.
// It can be used to connect code expecting an io.Reader // It can be used to connect code expecting a xio.Reader
// with code expecting an io.Writer. // with code expecting a xio.Writer.
// //
// Reads and Writes on the pipe are matched one to one // Reads and Writes on the pipe are matched one to one
// except when multiple Reads are needed to consume a single Write. // except when multiple Reads are needed to consume a single Write.
...@@ -191,6 +204,9 @@ func (w *PipeWriter) CloseWithError(err error) error { ...@@ -191,6 +204,9 @@ func (w *PipeWriter) CloseWithError(err error) error {
// It is safe to call Read and Write in parallel with each other or with Close. // It is safe to call Read and Write in parallel with each other or with Close.
// Parallel calls to Read and parallel calls to Write are also safe: // Parallel calls to Read and parallel calls to Write are also safe:
// the individual calls will be gated sequentially. // the individual calls will be gated sequentially.
//
// Pipe is similar to io.Pipe but additionally provides cancellation support
// for Read and Write.
func Pipe() (*PipeReader, *PipeWriter) { func Pipe() (*PipeReader, *PipeWriter) {
p := &pipe{ p := &pipe{
wrCh: make(chan []byte), wrCh: make(chan []byte),
......
...@@ -6,6 +6,7 @@ package xio_test ...@@ -6,6 +6,7 @@ package xio_test
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
. "lab.nexedi.com/kirr/go123/xio" . "lab.nexedi.com/kirr/go123/xio"
...@@ -15,8 +16,10 @@ import ( ...@@ -15,8 +16,10 @@ import (
"time" "time"
) )
func checkWrite(t *testing.T, w io.Writer, data []byte, c chan int) { var bg = context.Background()
n, err := w.Write(data)
func checkWrite(t *testing.T, w Writer, data []byte, c chan int) {
n, err := w.Write(bg, data)
if err != nil { if err != nil {
t.Errorf("write: %v", err) t.Errorf("write: %v", err)
} }
...@@ -32,7 +35,7 @@ func TestPipe1(t *testing.T) { ...@@ -32,7 +35,7 @@ func TestPipe1(t *testing.T) {
r, w := Pipe() r, w := Pipe()
var buf = make([]byte, 64) var buf = make([]byte, 64)
go checkWrite(t, w, []byte("hello, world"), c) go checkWrite(t, w, []byte("hello, world"), c)
n, err := r.Read(buf) n, err := r.Read(bg, buf)
if err != nil { if err != nil {
t.Errorf("read: %v", err) t.Errorf("read: %v", err)
} else if n != 12 || string(buf[0:12]) != "hello, world" { } else if n != 12 || string(buf[0:12]) != "hello, world" {
...@@ -43,10 +46,10 @@ func TestPipe1(t *testing.T) { ...@@ -43,10 +46,10 @@ func TestPipe1(t *testing.T) {
w.Close() w.Close()
} }
func reader(t *testing.T, r io.Reader, c chan int) { func reader(t *testing.T, r Reader, c chan int) {
var buf = make([]byte, 64) var buf = make([]byte, 64)
for { for {
n, err := r.Read(buf) n, err := r.Read(bg, buf)
if err == io.EOF { if err == io.EOF {
c <- 0 c <- 0
break break
...@@ -66,7 +69,7 @@ func TestPipe2(t *testing.T) { ...@@ -66,7 +69,7 @@ func TestPipe2(t *testing.T) {
var buf = make([]byte, 64) var buf = make([]byte, 64)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
p := buf[0 : 5+i*10] p := buf[0 : 5+i*10]
n, err := w.Write(p) n, err := w.Write(bg, p)
if n != len(p) { if n != len(p) {
t.Errorf("wrote %d, got %d", len(p), n) t.Errorf("wrote %d, got %d", len(p), n)
} }
...@@ -104,11 +107,11 @@ func TestPipe3(t *testing.T) { ...@@ -104,11 +107,11 @@ func TestPipe3(t *testing.T) {
for i := 0; i < len(wdat); i++ { for i := 0; i < len(wdat); i++ {
wdat[i] = byte(i) wdat[i] = byte(i)
} }
go writer(w, wdat, c) go writer(BindCtxWC(w, bg), wdat, c)
var rdat = make([]byte, 1024) var rdat = make([]byte, 1024)
tot := 0 tot := 0
for n := 1; n <= 256; n *= 2 { for n := 1; n <= 256; n *= 2 {
nn, err := r.Read(rdat[tot : tot+n]) nn, err := r.Read(bg, rdat[tot : tot+n])
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
t.Fatalf("read: %v", err) t.Fatalf("read: %v", err)
} }
...@@ -192,7 +195,7 @@ func TestPipeReadClose(t *testing.T) { ...@@ -192,7 +195,7 @@ func TestPipeReadClose(t *testing.T) {
delayClose(t, w, c, tt) delayClose(t, w, c, tt)
} }
var buf = make([]byte, 64) var buf = make([]byte, 64)
n, err := r.Read(buf) n, err := r.Read(bg, buf)
<-c <-c
want := tt.err want := tt.err
if want == nil { if want == nil {
...@@ -215,7 +218,7 @@ func TestPipeReadClose2(t *testing.T) { ...@@ -215,7 +218,7 @@ func TestPipeReadClose2(t *testing.T) {
c := make(chan int, 1) c := make(chan int, 1)
r, _ := Pipe() r, _ := Pipe()
go delayClose(t, r, c, pipeTest{}) go delayClose(t, r, c, pipeTest{})
n, err := r.Read(make([]byte, 64)) n, err := r.Read(bg, make([]byte, 64))
<-c <-c
if n != 0 || err != io.ErrClosedPipe { if n != 0 || err != io.ErrClosedPipe {
t.Errorf("read from closed pipe: %v, %v want %v, %v", n, err, 0, io.ErrClosedPipe) t.Errorf("read from closed pipe: %v, %v want %v, %v", n, err, 0, io.ErrClosedPipe)
...@@ -233,7 +236,7 @@ func TestPipeWriteClose(t *testing.T) { ...@@ -233,7 +236,7 @@ func TestPipeWriteClose(t *testing.T) {
} else { } else {
delayClose(t, r, c, tt) delayClose(t, r, c, tt)
} }
n, err := io.WriteString(w, "hello, world") n, err := io.WriteString(BindCtxW(w, bg), "hello, world")
<-c <-c
expect := tt.err expect := tt.err
if expect == nil { if expect == nil {
...@@ -256,7 +259,7 @@ func TestPipeWriteClose2(t *testing.T) { ...@@ -256,7 +259,7 @@ func TestPipeWriteClose2(t *testing.T) {
c := make(chan int, 1) c := make(chan int, 1)
_, w := Pipe() _, w := Pipe()
go delayClose(t, w, c, pipeTest{}) go delayClose(t, w, c, pipeTest{})
n, err := w.Write(make([]byte, 64)) n, err := w.Write(bg, make([]byte, 64))
<-c <-c
if n != 0 || err != io.ErrClosedPipe { if n != 0 || err != io.ErrClosedPipe {
t.Errorf("write to closed pipe: %v, %v want %v, %v", n, err, 0, io.ErrClosedPipe) t.Errorf("write to closed pipe: %v, %v want %v, %v", n, err, 0, io.ErrClosedPipe)
...@@ -266,22 +269,22 @@ func TestPipeWriteClose2(t *testing.T) { ...@@ -266,22 +269,22 @@ func TestPipeWriteClose2(t *testing.T) {
func TestWriteEmpty(t *testing.T) { func TestWriteEmpty(t *testing.T) {
r, w := Pipe() r, w := Pipe()
go func() { go func() {
w.Write([]byte{}) w.Write(bg, []byte{})
w.Close() w.Close()
}() }()
var b [2]byte var b [2]byte
io.ReadFull(r, b[0:2]) io.ReadFull(BindCtxR(r, bg), b[0:2])
r.Close() r.Close()
} }
func TestWriteNil(t *testing.T) { func TestWriteNil(t *testing.T) {
r, w := Pipe() r, w := Pipe()
go func() { go func() {
w.Write(nil) w.Write(bg, nil)
w.Close() w.Close()
}() }()
var b [2]byte var b [2]byte
io.ReadFull(r, b[0:2]) io.ReadFull(BindCtxR(r, bg), b[0:2])
r.Close() r.Close()
} }
...@@ -291,18 +294,18 @@ func TestWriteAfterWriterClose(t *testing.T) { ...@@ -291,18 +294,18 @@ func TestWriteAfterWriterClose(t *testing.T) {
done := make(chan bool) done := make(chan bool)
var writeErr error var writeErr error
go func() { go func() {
_, err := w.Write([]byte("hello")) _, err := w.Write(bg, []byte("hello"))
if err != nil { if err != nil {
t.Errorf("got error: %q; expected none", err) t.Errorf("got error: %q; expected none", err)
} }
w.Close() w.Close()
_, writeErr = w.Write([]byte("world")) _, writeErr = w.Write(bg, []byte("world"))
done <- true done <- true
}() }()
buf := make([]byte, 100) buf := make([]byte, 100)
var result string var result string
n, err := io.ReadFull(r, buf) n, err := io.ReadFull(BindCtxR(r, bg), buf)
if err != nil && err != io.ErrUnexpectedEOF { if err != nil && err != io.ErrUnexpectedEOF {
t.Fatalf("got: %q; want: %q", err, io.ErrUnexpectedEOF) t.Fatalf("got: %q; want: %q", err, io.ErrUnexpectedEOF)
} }
...@@ -323,21 +326,21 @@ func TestPipeCloseError(t *testing.T) { ...@@ -323,21 +326,21 @@ func TestPipeCloseError(t *testing.T) {
r, w := Pipe() r, w := Pipe()
r.CloseWithError(testError1{}) r.CloseWithError(testError1{})
if _, err := w.Write(nil); err != (testError1{}) { if _, err := w.Write(bg, nil); err != (testError1{}) {
t.Errorf("Write error: got %T, want testError1", err) t.Errorf("Write error: got %T, want testError1", err)
} }
r.CloseWithError(testError2{}) r.CloseWithError(testError2{})
if _, err := w.Write(nil); err != (testError1{}) { if _, err := w.Write(bg, nil); err != (testError1{}) {
t.Errorf("Write error: got %T, want testError1", err) t.Errorf("Write error: got %T, want testError1", err)
} }
r, w = Pipe() r, w = Pipe()
w.CloseWithError(testError1{}) w.CloseWithError(testError1{})
if _, err := r.Read(nil); err != (testError1{}) { if _, err := r.Read(bg, nil); err != (testError1{}) {
t.Errorf("Read error: got %T, want testError1", err) t.Errorf("Read error: got %T, want testError1", err)
} }
w.CloseWithError(testError2{}) w.CloseWithError(testError2{})
if _, err := r.Read(nil); err != (testError1{}) { if _, err := r.Read(bg, nil); err != (testError1{}) {
t.Errorf("Read error: got %T, want testError1", err) t.Errorf("Read error: got %T, want testError1", err)
} }
} }
...@@ -355,7 +358,7 @@ func TestPipeConcurrent(t *testing.T) { ...@@ -355,7 +358,7 @@ func TestPipeConcurrent(t *testing.T) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
go func() { go func() {
time.Sleep(time.Millisecond) // Increase probability of race time.Sleep(time.Millisecond) // Increase probability of race
if n, err := w.Write([]byte(input)); n != len(input) || err != nil { if n, err := w.Write(bg, []byte(input)); n != len(input) || err != nil {
t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input)) t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input))
} }
}() }()
...@@ -363,7 +366,7 @@ func TestPipeConcurrent(t *testing.T) { ...@@ -363,7 +366,7 @@ func TestPipeConcurrent(t *testing.T) {
buf := make([]byte, count*len(input)) buf := make([]byte, count*len(input))
for i := 0; i < len(buf); i += readSize { for i := 0; i < len(buf); i += readSize {
if n, err := r.Read(buf[i : i+readSize]); n != readSize || err != nil { if n, err := r.Read(bg, buf[i : i+readSize]); n != readSize || err != nil {
t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize) t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize)
} }
} }
...@@ -385,7 +388,7 @@ func TestPipeConcurrent(t *testing.T) { ...@@ -385,7 +388,7 @@ func TestPipeConcurrent(t *testing.T) {
go func() { go func() {
time.Sleep(time.Millisecond) // Increase probability of race time.Sleep(time.Millisecond) // Increase probability of race
buf := make([]byte, readSize) buf := make([]byte, readSize)
if n, err := r.Read(buf); n != readSize || err != nil { if n, err := r.Read(bg, buf); n != readSize || err != nil {
t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize) t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize)
} }
c <- buf c <- buf
...@@ -393,7 +396,7 @@ func TestPipeConcurrent(t *testing.T) { ...@@ -393,7 +396,7 @@ func TestPipeConcurrent(t *testing.T) {
} }
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if n, err := w.Write([]byte(input)); n != len(input) || err != nil { if n, err := w.Write(bg, []byte(input)); n != len(input) || err != nil {
t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input)) t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input))
} }
} }
...@@ -422,3 +425,31 @@ func sortBytesInGroups(b []byte, n int) []byte { ...@@ -422,3 +425,31 @@ func sortBytesInGroups(b []byte, n int) []byte {
sort.Slice(groups, func(i, j int) bool { return bytes.Compare(groups[i], groups[j]) < 0 }) sort.Slice(groups, func(i, j int) bool { return bytes.Compare(groups[i], groups[j]) < 0 })
return bytes.Join(groups, nil) return bytes.Join(groups, nil)
} }
// Verify that .Read and .Write handle cancellation.
func TestPipeCancel(t *testing.T) {
buf := make([]byte, 64)
r, _ := Pipe()
ctx, cancel := context.WithCancel(bg)
go func() {
time.Sleep(1*time.Millisecond)
cancel()
}()
n, err := r.Read(ctx, buf)
if eok := context.Canceled; !(n == 0 && err == eok) {
t.Errorf("read: got (%v, %v) ; want (%v, %v)", n, err, 0, eok)
}
_, w := Pipe()
ctx, cancel = context.WithTimeout(bg, 1*time.Millisecond)
n, err = w.Write(ctx, buf)
if eok := context.DeadlineExceeded; !(n == 0 && err == eok) {
t.Errorf("write: got (%v, %v) ; want (%v, %v)", n, err, 0, eok)
}
}
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
// It is the opposite operation for BindCtx, but for arbitrary io.X // It is the opposite operation for BindCtx, but for arbitrary io.X
// returned xio.X handles context only on best-effort basis. In // returned xio.X handles context only on best-effort basis. In
// particular IO cancellation is not reliably handled for os.File . // particular IO cancellation is not reliably handled for os.File .
// - Pipe amends io.Pipe and creates synchronous in-memory pipe that
// supports IO cancellation.
// //
// Miscellaneous utilities: // Miscellaneous utilities:
// //
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment