Commit 8b21b5b7 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: use public MockCommunicator, tests pass

parent 46e02209
package packer package packer
import ( import (
"bytes"
"io" "io"
"sync"
) )
// MockCommunicator is a valid Communicator implementation that can be // MockCommunicator is a valid Communicator implementation that can be
// used for tests. // used for tests.
type MockCommunicator struct { type MockCommunicator struct {
Stderr io.Reader StartCalled bool
Stdout io.Reader StartCmd *RemoteCmd
StartStderr string
StartStdout string
StartStdin string
StartExitStatus int
UploadCalled bool
UploadPath string
UploadData string
DownloadCalled bool
DownloadPath string
DownloadData string
} }
func (c *MockCommunicator) Start(rc *RemoteCmd) error { func (c *MockCommunicator) Start(rc *RemoteCmd) error {
c.StartCalled = true
c.StartCmd = rc
go func() { go func() {
rc.Lock() var wg sync.WaitGroup
defer rc.Unlock() if rc.Stdout != nil && c.StartStdout != "" {
wg.Add(1)
go func() {
rc.Stdout.Write([]byte(c.StartStdout))
wg.Done()
}()
}
if rc.Stdout != nil && c.Stdout != nil { if rc.Stderr != nil && c.StartStderr != "" {
io.Copy(rc.Stdout, c.Stdout) wg.Add(1)
go func() {
rc.Stderr.Write([]byte(c.StartStderr))
wg.Done()
}()
} }
if rc.Stderr != nil && c.Stderr != nil { if rc.Stdin != nil {
io.Copy(rc.Stderr, c.Stderr) wg.Add(1)
go func() {
defer wg.Done()
var data bytes.Buffer
io.Copy(&data, rc.Stdin)
c.StartStdin = data.String()
}()
} }
wg.Wait()
rc.SetExited(c.StartExitStatus)
}() }()
return nil return nil
} }
func (c *MockCommunicator) Upload(string, io.Reader) error { func (c *MockCommunicator) Upload(path string, r io.Reader) error {
c.UploadCalled = true
c.UploadPath = path
var data bytes.Buffer
if _, err := io.Copy(&data, r); err != nil {
panic(err)
}
c.UploadData = data.String()
return nil return nil
} }
...@@ -36,6 +82,10 @@ func (c *MockCommunicator) UploadDir(string, string, []string) error { ...@@ -36,6 +82,10 @@ func (c *MockCommunicator) UploadDir(string, string, []string) error {
return nil return nil
} }
func (c *MockCommunicator) Download(string, io.Writer) error { func (c *MockCommunicator) Download(path string, w io.Writer) error {
c.DownloadCalled = true
c.DownloadPath = path
w.Write([]byte(c.DownloadData))
return nil return nil
} }
...@@ -11,14 +11,10 @@ func TestRemoteCmd_StartWithUi(t *testing.T) { ...@@ -11,14 +11,10 @@ func TestRemoteCmd_StartWithUi(t *testing.T) {
data := "hello\nworld\nthere" data := "hello\nworld\nthere"
originalOutput := new(bytes.Buffer) originalOutput := new(bytes.Buffer)
rcOutput := new(bytes.Buffer)
uiOutput := new(bytes.Buffer) uiOutput := new(bytes.Buffer)
rcOutput.WriteString(data)
testComm := &MockCommunicator{
Stdout: rcOutput,
}
testComm := new(MockCommunicator)
testComm.StartStdout = data
testUi := &BasicUi{ testUi := &BasicUi{
Reader: new(bytes.Buffer), Reader: new(bytes.Buffer),
Writer: uiOutput, Writer: uiOutput,
...@@ -29,22 +25,20 @@ func TestRemoteCmd_StartWithUi(t *testing.T) { ...@@ -29,22 +25,20 @@ func TestRemoteCmd_StartWithUi(t *testing.T) {
Stdout: originalOutput, Stdout: originalOutput,
} }
go func() {
time.Sleep(100 * time.Millisecond)
rc.SetExited(0)
}()
err := rc.StartWithUi(testComm, testUi) err := rc.StartWithUi(testComm, testUi)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if uiOutput.String() != strings.TrimSpace(data)+"\n" { rc.Wait()
expected := strings.TrimSpace(data)
if uiOutput.String() != expected+"\n" {
t.Fatalf("bad output: '%s'", uiOutput.String()) t.Fatalf("bad output: '%s'", uiOutput.String())
} }
if originalOutput.String() != data { if originalOutput.String() != expected {
t.Fatalf("original is bad: '%s'", originalOutput.String()) t.Fatalf("bad: %#v", originalOutput.String())
} }
} }
......
...@@ -123,6 +123,10 @@ func (c *communicator) Upload(path string, r io.Reader) (err error) { ...@@ -123,6 +123,10 @@ func (c *communicator) Upload(path string, r io.Reader) (err error) {
return return
} }
func (c *communicator) UploadDir(dst string, src string, exclude []string) error {
return nil
}
func (c *communicator) Download(path string, w io.Writer) (err error) { func (c *communicator) Download(path string, w io.Writer) (err error) {
// We need to create a server that can proxy that data downloaded // We need to create a server that can proxy that data downloaded
// into the writer because we can't gob encode a writer directly. // into the writer because we can't gob encode a writer directly.
......
...@@ -2,52 +2,15 @@ package rpc ...@@ -2,52 +2,15 @@ package rpc
import ( import (
"bufio" "bufio"
"cgl.tideland.biz/asserts"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"io" "io"
"net/rpc" "net/rpc"
"testing" "testing"
"time"
) )
type testCommunicator struct {
startCalled bool
startCmd *packer.RemoteCmd
uploadCalled bool
uploadPath string
uploadData string
downloadCalled bool
downloadPath string
}
func (t *testCommunicator) Start(cmd *packer.RemoteCmd) error {
t.startCalled = true
t.startCmd = cmd
return nil
}
func (t *testCommunicator) Upload(path string, reader io.Reader) (err error) {
t.uploadCalled = true
t.uploadPath = path
t.uploadData, err = bufio.NewReader(reader).ReadString('\n')
return
}
func (t *testCommunicator) Download(path string, writer io.Writer) error {
t.downloadCalled = true
t.downloadPath = path
writer.Write([]byte("download\n"))
return nil
}
func TestCommunicatorRPC(t *testing.T) { func TestCommunicatorRPC(t *testing.T) {
assert := asserts.NewTestingAsserts(t, true)
// Create the interface to test // Create the interface to test
c := new(testCommunicator) c := new(packer.MockCommunicator)
// Start the server // Start the server
server := rpc.NewServer() server := rpc.NewServer()
...@@ -56,7 +19,9 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -56,7 +19,9 @@ func TestCommunicatorRPC(t *testing.T) {
// Create the client over RPC and run some methods to verify it works // Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address) client, err := rpc.Dial("tcp", address)
assert.Nil(err, "should be able to connect") if err != nil {
t.Fatalf("err: %s", err)
}
remote := Communicator(client) remote := Communicator(client)
// The remote command we'll use // The remote command we'll use
...@@ -70,56 +35,74 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -70,56 +35,74 @@ func TestCommunicatorRPC(t *testing.T) {
cmd.Stdout = stdout_w cmd.Stdout = stdout_w
cmd.Stderr = stderr_w cmd.Stderr = stderr_w
// Send some data on stdout and stderr from the mock
c.StartStdout = "outfoo\n"
c.StartStderr = "errfoo\n"
c.StartExitStatus = 42
// Test Start // Test Start
err = remote.Start(&cmd) err = remote.Start(&cmd)
assert.Nil(err, "should not have an error") if err != nil {
t.Fatalf("err: %s", err)
}
// Test that we can read from stdout // Test that we can read from stdout
c.startCmd.Stdout.Write([]byte("outfoo\n"))
bufOut := bufio.NewReader(stdout_r) bufOut := bufio.NewReader(stdout_r)
data, err := bufOut.ReadString('\n') data, err := bufOut.ReadString('\n')
assert.Nil(err, "should have no problem reading stdout") if err != nil {
assert.Equal(data, "outfoo\n", "should be correct stdout") t.Fatalf("err: %s", err)
}
if data != "outfoo\n" {
t.Fatalf("bad data: %s", data)
}
// Test that we can read from stderr // Test that we can read from stderr
c.startCmd.Stderr.Write([]byte("errfoo\n"))
bufErr := bufio.NewReader(stderr_r) bufErr := bufio.NewReader(stderr_r)
data, err = bufErr.ReadString('\n') data, err = bufErr.ReadString('\n')
assert.Nil(err, "should have no problem reading stderr") if err != nil {
assert.Equal(data, "errfoo\n", "should be correct stderr") t.Fatalf("err: %s", err)
}
if data != "errfoo\n" {
t.Fatalf("bad data: %s", data)
}
// Test that we can write to stdin // Test that we can write to stdin
stdin_w.Write([]byte("infoo\n")) stdin_w.Write([]byte("info\n"))
bufIn := bufio.NewReader(c.startCmd.Stdin) stdin_w.Close()
data, err = bufIn.ReadString('\n') cmd.Wait()
assert.Nil(err, "should have no problem reading stdin") if c.StartStdin != "info\n" {
assert.Equal(data, "infoo\n", "should be correct stdin") t.Fatalf("bad data: %s", data)
}
// Test that we can get the exit status properly // Test that we can get the exit status properly
c.startCmd.SetExited(42) if cmd.ExitStatus != 42 {
t.Fatalf("bad exit: %d", cmd.ExitStatus)
for i := 0; i < 5; i++ {
cmd.Lock()
exited := cmd.Exited
cmd.Unlock()
if exited {
assert.Equal(cmd.ExitStatus, 42, "should have proper exit status")
break
}
time.Sleep(50 * time.Millisecond)
} }
assert.True(cmd.Exited, "should have exited")
// Test that we can upload things // Test that we can upload things
uploadR, uploadW := io.Pipe() uploadR, uploadW := io.Pipe()
go uploadW.Write([]byte("uploadfoo\n")) go func() {
defer uploadW.Close()
uploadW.Write([]byte("uploadfoo\n"))
}()
err = remote.Upload("foo", uploadR) err = remote.Upload("foo", uploadR)
assert.Nil(err, "should not error") if err != nil {
assert.True(c.uploadCalled, "should be called") t.Fatalf("err: %s", err)
assert.Equal(c.uploadPath, "foo", "should be correct path") }
assert.Equal(c.uploadData, "uploadfoo\n", "should have the proper data")
if !c.UploadCalled {
t.Fatal("should have uploaded")
}
if c.UploadPath != "foo" {
t.Fatalf("path: %s", c.UploadPath)
}
if c.UploadData != "uploadfoo\n" {
t.Fatalf("bad: %s", c.UploadData)
}
// Test that we can download things // Test that we can download things
downloadR, downloadW := io.Pipe() downloadR, downloadW := io.Pipe()
...@@ -133,21 +116,34 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -133,21 +116,34 @@ func TestCommunicatorRPC(t *testing.T) {
downloadDone <- true downloadDone <- true
}() }()
c.DownloadData = "download\n"
err = remote.Download("bar", downloadW) err = remote.Download("bar", downloadW)
assert.Nil(err, "should not error") if err != nil {
assert.True(c.downloadCalled, "should have called download") t.Fatalf("err: %s", err)
assert.Equal(c.downloadPath, "bar", "should have correct download path") }
if !c.DownloadCalled {
t.Fatal("download should be called")
}
if c.DownloadPath != "bar" {
t.Fatalf("bad: %s", c.DownloadPath)
}
<-downloadDone <-downloadDone
assert.Nil(downloadErr, "should not error reading download data") if downloadErr != nil {
assert.Equal(downloadData, "download\n", "should have the proper data") t.Fatalf("err: %s", downloadErr)
}
if downloadData != "download\n" {
t.Fatalf("bad: %s", downloadData)
}
} }
func TestCommunicator_ImplementsCommunicator(t *testing.T) { func TestCommunicator_ImplementsCommunicator(t *testing.T) {
assert := asserts.NewTestingAsserts(t, true) var raw interface{}
raw = Communicator(nil)
var r packer.Communicator if _, ok := raw.(packer.Communicator); !ok {
c := Communicator(nil) t.Fatal("should be a Communicator")
}
assert.Implementor(c, &r, "should be a Communicator")
} }
...@@ -52,7 +52,7 @@ func TestProvisionerRPC(t *testing.T) { ...@@ -52,7 +52,7 @@ func TestProvisionerRPC(t *testing.T) {
// Test Provision // Test Provision
ui := &testUi{} ui := &testUi{}
comm := &testCommunicator{} comm := new(packer.MockCommunicator)
pClient.Provision(ui, comm) pClient.Provision(ui, comm)
assert.True(p.provCalled, "provision should be called") assert.True(p.provCalled, "provision should be called")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment