Commit c75ffbd9 authored by Han-Wen Nienhuys's avatar Han-Wen Nienhuys

Track handed out buffers explicitly.

This makes sure we never take ownership of read only data from a file
system.
parent 67340a77
...@@ -159,7 +159,7 @@ type PassThroughFile struct { ...@@ -159,7 +159,7 @@ type PassThroughFile struct {
} }
func (self *PassThroughFile) Read(input *fuse.ReadIn, buffers *fuse.BufferPool) ([]byte, fuse.Status) { func (self *PassThroughFile) Read(input *fuse.ReadIn, buffers *fuse.BufferPool) ([]byte, fuse.Status) {
slice := buffers.GetBuffer(input.Size) slice := buffers.AllocBuffer(input.Size)
n, err := self.file.ReadAt(slice, int64(input.Offset)) n, err := self.file.ReadAt(slice, int64(input.Offset))
if err == os.EOF { if err == os.EOF {
......
...@@ -73,7 +73,7 @@ func (me *ZipDirTree) FindDir(name string) *ZipDirTree { ...@@ -73,7 +73,7 @@ func (me *ZipDirTree) FindDir(name string) *ZipDirTree {
type ZipFileFuse struct { type ZipFileFuse struct {
zipReader *zip.Reader zipReader *zip.Reader
tree *ZipDirTree tree *ZipDirTree
fuse.DefaultPathFilesystem fuse.DefaultPathFilesystem
} }
...@@ -204,7 +204,6 @@ func (self *ZipFile) Read(input *fuse.ReadIn, bp *fuse.BufferPool) ([]byte, fuse ...@@ -204,7 +204,6 @@ func (self *ZipFile) Read(input *fuse.ReadIn, bp *fuse.BufferPool) ([]byte, fuse
end = len(self.data) end = len(self.data)
} }
// TODO - robustify bufferpool
return self.data[input.Offset:end], fuse.OK return self.data[input.Offset:end], fuse.OK
} }
......
...@@ -3,6 +3,7 @@ package fuse ...@@ -3,6 +3,7 @@ package fuse
import ( import (
"sync" "sync"
"fmt" "fmt"
"unsafe"
) )
// This implements a pool of buffers that returns slices with capacity // This implements a pool of buffers that returns slices with capacity
...@@ -13,6 +14,9 @@ type BufferPool struct { ...@@ -13,6 +14,9 @@ type BufferPool struct {
// For each exponent a list of slice pointers. // For each exponent a list of slice pointers.
buffersByExponent [][][]byte buffersByExponent [][][]byte
// start of slice -> exponent.
outstandingBuffers map[uintptr]uint
} }
// Returns the smallest E such that 2^E >= Z. // Returns the smallest E such that 2^E >= Z.
...@@ -33,6 +37,7 @@ func IntToExponent(z int) uint { ...@@ -33,6 +37,7 @@ func IntToExponent(z int) uint {
func NewBufferPool() *BufferPool { func NewBufferPool() *BufferPool {
bp := new(BufferPool) bp := new(BufferPool)
bp.buffersByExponent = make([][][]byte, 0, 8) bp.buffersByExponent = make([][][]byte, 0, 8)
bp.outstandingBuffers = make(map[uintptr]uint)
return bp return bp
} }
...@@ -44,12 +49,7 @@ func (self *BufferPool) String() string { ...@@ -44,12 +49,7 @@ func (self *BufferPool) String() string {
return s return s
} }
func (self *BufferPool) getBuffer(exponent uint) []byte {
func (self *BufferPool) getBuffer(sz int) []byte {
exponent := int(IntToExponent(sz) - IntToExponent(PAGESIZE))
self.lock.Lock()
defer self.lock.Unlock()
if len(self.buffersByExponent) <= int(exponent) { if len(self.buffersByExponent) <= int(exponent) {
return nil return nil
} }
...@@ -60,30 +60,10 @@ func (self *BufferPool) getBuffer(sz int) []byte { ...@@ -60,30 +60,10 @@ func (self *BufferPool) getBuffer(sz int) []byte {
result := bufferList[len(bufferList)-1] result := bufferList[len(bufferList)-1]
self.buffersByExponent[exponent] = self.buffersByExponent[exponent][:len(bufferList)-1] self.buffersByExponent[exponent] = self.buffersByExponent[exponent][:len(bufferList)-1]
if cap(result) < sz {
panic("returning incorrect buffer.")
}
return result return result
} }
func (self *BufferPool) addBuffer(slice []byte) { func (self *BufferPool) addBuffer(slice []byte, exp uint) {
if cap(slice)&(PAGESIZE-1) != 0 {
return
}
pages := cap(slice) / PAGESIZE
if pages == 0 {
return
}
exp := IntToExponent(pages)
if (1 << exp) != pages {
return
}
self.lock.Lock()
defer self.lock.Unlock()
for len(self.buffersByExponent) <= int(exp) { for len(self.buffersByExponent) <= int(exp) {
self.buffersByExponent = append(self.buffersByExponent, make([][]byte, 0)) self.buffersByExponent = append(self.buffersByExponent, make([][]byte, 0))
} }
...@@ -91,18 +71,46 @@ func (self *BufferPool) addBuffer(slice []byte) { ...@@ -91,18 +71,46 @@ func (self *BufferPool) addBuffer(slice []byte) {
} }
func (self *BufferPool) GetBuffer(size uint32) []byte { func (self *BufferPool) AllocBuffer(size uint32) []byte {
sz := int(size) sz := int(size)
if sz < PAGESIZE { if sz < PAGESIZE {
sz = PAGESIZE sz = PAGESIZE
} }
rounded := 1 << IntToExponent(sz)
b := self.getBuffer(rounded) exp := IntToExponent(sz)
rounded := 1 << exp
exp -= IntToExponent(PAGESIZE)
self.lock.Lock()
defer self.lock.Unlock()
b := self.getBuffer(exp)
if b != nil { if b != nil {
b = b[:size] b = b[:size]
return b return b
} }
return make([]byte, size, rounded) b = make([]byte, size, rounded)
self.outstandingBuffers[uintptr(unsafe.Pointer(&b[0]))] = exp
return b
}
// Takes back a buffer if it was allocated through AllocBuffer. It is
// not an error to call FreeBuffer() on a slice obtained elsewhere.
func (self *BufferPool) FreeBuffer(slice []byte) {
self.lock.Lock()
defer self.lock.Unlock()
if cap(slice) < PAGESIZE {
return
}
key := uintptr(unsafe.Pointer(&slice[0]))
exp, ok := self.outstandingBuffers[key]
if ok {
self.addBuffer(slice, exp)
self.outstandingBuffers[key] = 0, false
}
} }
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"testing" "testing"
"fmt" "fmt"
) )
var _ = fmt.Println
func TestIntToExponent(t *testing.T) { func TestIntToExponent(t *testing.T) {
e := IntToExponent(1) e := IntToExponent(1)
...@@ -27,53 +28,13 @@ func TestIntToExponent(t *testing.T) { ...@@ -27,53 +28,13 @@ func TestIntToExponent(t *testing.T) {
func TestBufferPool(t *testing.T) { func TestBufferPool(t *testing.T) {
bp := NewBufferPool() bp := NewBufferPool()
b := bp.getBuffer(PAGESIZE - 1) b1 := bp.AllocBuffer(PAGESIZE)
if b != nil { _ = bp.AllocBuffer(2*PAGESIZE)
t.Error("bp 0") bp.FreeBuffer(b1)
}
b = bp.getBuffer(PAGESIZE)
if b != nil {
t.Error("bp 1")
}
s := make([]byte, PAGESIZE-1)
bp.addBuffer(s)
b = bp.getBuffer(PAGESIZE - 1)
if b != nil {
t.Error("bp 3")
}
s = make([]byte, PAGESIZE)
bp.addBuffer(s)
b = bp.getBuffer(PAGESIZE)
if b == nil {
t.Error("not found.")
}
b = bp.getBuffer(PAGESIZE)
if b != nil {
t.Error("should fail.")
}
bp.addBuffer(make([]byte, 3*PAGESIZE)) b1_2 := bp.AllocBuffer(PAGESIZE)
b = bp.getBuffer(2 * PAGESIZE) if &b1[0] != &b1_2[0] {
if b != nil { t.Error("bp 0")
t.Error("should fail.")
}
b = bp.getBuffer(4 * PAGESIZE)
if b != nil {
t.Error("should fail.")
}
bp.addBuffer(make([]byte, 4*PAGESIZE))
fmt.Println(bp)
b = bp.getBuffer(2 * PAGESIZE)
if b != nil {
t.Error("should fail.")
}
b = bp.getBuffer(4 * PAGESIZE)
if b == nil {
t.Error("4*ps should succeed.")
} }
} }
...@@ -110,6 +110,6 @@ func (me *FuseDir) ReadDir(input *ReadIn) (*DirEntryList, Status) { ...@@ -110,6 +110,6 @@ func (me *FuseDir) ReadDir(input *ReadIn) (*DirEntryList, Status) {
} }
func (me *FuseDir) ReleaseDir() { func (me *FuseDir) ReleaseDir() {
close(me.stream) // TODO - should close ?
} }
...@@ -222,7 +222,7 @@ func (self *MountState) syncWrite(packet [][]byte) { ...@@ -222,7 +222,7 @@ func (self *MountState) syncWrite(packet [][]byte) {
self.Error(os.NewError(fmt.Sprintf("writer: Writev %v failed, err: %v", packet, err))) self.Error(os.NewError(fmt.Sprintf("writer: Writev %v failed, err: %v", packet, err)))
} }
for _, v := range packet { for _, v := range packet {
self.buffers.addBuffer(v) self.buffers.FreeBuffer(v)
} }
} }
...@@ -233,7 +233,7 @@ func (self *MountState) syncWrite(packet [][]byte) { ...@@ -233,7 +233,7 @@ func (self *MountState) syncWrite(packet [][]byte) {
func (self *MountState) loop() { func (self *MountState) loop() {
// See fuse_kern_chan_receive() // See fuse_kern_chan_receive()
for { for {
buf := self.buffers.GetBuffer(bufSize) buf := self.buffers.AllocBuffer(bufSize)
n, err := self.mountFile.Read(buf) n, err := self.mountFile.Read(buf)
if err != nil { if err != nil {
errNo := OsErrorToFuseError(err) errNo := OsErrorToFuseError(err)
...@@ -280,7 +280,7 @@ func (self *MountState) handle(in_data []byte) { ...@@ -280,7 +280,7 @@ func (self *MountState) handle(in_data []byte) {
return return
} }
self.Write(dispatch(self, header, r)) self.Write(dispatch(self, header, r))
self.buffers.addBuffer(in_data) self.buffers.FreeBuffer(in_data)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment