Commit ea4171f1 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/rpc: Hook up the new communicator interface

parent 532faec4
...@@ -21,17 +21,11 @@ type CommunicatorServer struct { ...@@ -21,17 +21,11 @@ type CommunicatorServer struct {
c packer.Communicator c packer.Communicator
} }
// RemoteCommandServer wraps a packer.RemoteCommand struct and makes it type CommunicatorStartArgs struct {
// exportable as part of a Golang RPC server. Command string
type RemoteCommandServer struct {
rc *packer.RemoteCommand
}
type CommunicatorStartResponse struct {
StdinAddress string StdinAddress string
StdoutAddress string StdoutAddress string
StderrAddress string StderrAddress string
RemoteCommandAddress string
} }
type CommunicatorDownloadArgs struct { type CommunicatorDownloadArgs struct {
...@@ -48,52 +42,29 @@ func Communicator(client *rpc.Client) *communicator { ...@@ -48,52 +42,29 @@ func Communicator(client *rpc.Client) *communicator {
return &communicator{client} return &communicator{client}
} }
func (c *communicator) Start(cmd string) (rc *packer.RemoteCommand, err error) { func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
var response CommunicatorStartResponse var args CommunicatorStartArgs
err = c.client.Call("Communicator.Start", &cmd, &response) args.Command = cmd.Command
if err != nil {
return
}
// Connect to the three streams that will handle stdin, stdout,
// and stderr and get net.Conns for them.
stdinC, err := net.Dial("tcp", response.StdinAddress)
if err != nil {
return
}
stdoutC, err := net.Dial("tcp", response.StdoutAddress) if cmd.Stdin != nil {
if err != nil { stdinL := netListenerInRange(portRangeMin, portRangeMax)
return args.StdinAddress = stdinL.Addr().String()
} go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin)
stderrC, err := net.Dial("tcp", response.StderrAddress)
if err != nil {
return
} }
// Connect to the RPC server for the remote command if cmd.Stdout != nil {
client, err := rpc.Dial("tcp", response.RemoteCommandAddress) stdoutL := netListenerInRange(portRangeMin, portRangeMax)
if err != nil { args.StdoutAddress = stdoutL.Addr().String()
return go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil)
} }
// Build the response object using the streams we created if cmd.Stderr != nil {
rc = &packer.RemoteCommand{ stderrL := netListenerInRange(portRangeMin, portRangeMax)
Stdin: stdinC, args.StderrAddress = stderrL.Addr().String()
Stdout: stdoutC, go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil)
Stderr: stderrC,
Exited: false,
ExitStatus: -1,
} }
// In a goroutine, we wait for the process to exit, then we set err = c.client.Call("Communicator.Start", &args, new(interface{}))
// that it has exited.
go func() {
client.Call("RemoteCommand.Wait", new(interface{}), &rc.ExitStatus)
rc.Exited = true
}()
return return
} }
...@@ -145,41 +116,41 @@ func (c *communicator) Download(path string, w io.Writer) (err error) { ...@@ -145,41 +116,41 @@ func (c *communicator) Download(path string, w io.Writer) (err error) {
return return
} }
func (c *CommunicatorServer) Start(cmd *string, reply *CommunicatorStartResponse) (err error) { func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) {
// Start executing the command. // Build the RemoteCmd on this side so that it all pipes over
command, err := c.c.Start(*cmd) // to the remote side.
var cmd packer.RemoteCmd
cmd.Command = args.Command
if args.StdinAddress != "" {
stdinC, err := net.Dial("tcp", args.StdinAddress)
if err != nil { if err != nil {
return return err
} }
// If we didn't get a proper command... that's not right. cmd.Stdin = stdinC
if command == nil {
return errors.New("communicator returned nil remote command")
} }
// Next, we need to take the stdin/stdout and start a listener if args.StdoutAddress != "" {
// for each because the client will connect to us via TCP and use stdoutC, err := net.Dial("tcp", args.StdoutAddress)
// that connection as the io.Reader or io.Writer. These exist for if err != nil {
// only a single connection that is persistent. return err
stdinL := netListenerInRange(portRangeMin, portRangeMax) }
stdoutL := netListenerInRange(portRangeMin, portRangeMax)
stderrL := netListenerInRange(portRangeMin, portRangeMax) cmd.Stdout = stdoutC
go serveSingleCopy("stdin", stdinL, command.Stdin, nil) }
go serveSingleCopy("stdout", stdoutL, nil, command.Stdout)
go serveSingleCopy("stderr", stderrL, nil, command.Stderr)
// For the exit status, we use a simple RPC Server that serves if args.StderrAddress != "" {
// some of the RemoteComand methods. stderrC, err := net.Dial("tcp", args.StderrAddress)
server := rpc.NewServer() if err != nil {
server.RegisterName("RemoteCommand", &RemoteCommandServer{command}) return err
}
*reply = CommunicatorStartResponse{ cmd.Stderr = stderrC
stdinL.Addr().String(),
stdoutL.Addr().String(),
stderrL.Addr().String(),
serveSingleConn(server),
} }
// Start the actual command
err = c.c.Start(&cmd)
return return
} }
...@@ -207,12 +178,6 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int ...@@ -207,12 +178,6 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int
return return
} }
func (rc *RemoteCommandServer) Wait(args *interface{}, reply *int) error {
rc.rc.Wait()
*reply = rc.rc.ExitStatus
return nil
}
func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) { func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) {
defer l.Close() defer l.Close()
......
...@@ -11,13 +11,7 @@ import ( ...@@ -11,13 +11,7 @@ import (
type testCommunicator struct { type testCommunicator struct {
startCalled bool startCalled bool
startCmd string startCmd *packer.RemoteCmd
startIn *io.PipeReader
startOut *io.PipeWriter
startErr *io.PipeWriter
startExited *bool
startExitStatus *int
uploadCalled bool uploadCalled bool
uploadPath string uploadPath string
...@@ -27,29 +21,10 @@ type testCommunicator struct { ...@@ -27,29 +21,10 @@ type testCommunicator struct {
downloadPath string downloadPath string
} }
func (t *testCommunicator) Start(cmd string) (*packer.RemoteCommand, error) { func (t *testCommunicator) Start(cmd *packer.RemoteCmd) error {
t.startCalled = true t.startCalled = true
t.startCmd = cmd t.startCmd = cmd
return nil
var stdin *io.PipeWriter
var stdout, stderr *io.PipeReader
t.startIn, stdin = io.Pipe()
stdout, t.startOut = io.Pipe()
stderr, t.startErr = io.Pipe()
rc := &packer.RemoteCommand{
Stdin: stdin,
Stdout: stdout,
Stderr: stderr,
Exited: false,
ExitStatus: 0,
}
t.startExited = &rc.Exited
t.startExitStatus = &rc.ExitStatus
return rc, nil
} }
func (t *testCommunicator) Upload(path string, reader io.Reader) (err error) { func (t *testCommunicator) Upload(path string, reader io.Reader) (err error) {
...@@ -81,38 +56,46 @@ func TestCommunicatorRPC(t *testing.T) { ...@@ -81,38 +56,46 @@ func TestCommunicatorRPC(t *testing.T) {
// Create the client over RPC and run some methods to verify it works // Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address) client, err := rpc.Dial("tcp", address)
assert.Nil(err, "should be able to connect") assert.Nil(err, "should be able to connect")
remote := Communicator(client)
// The remote command we'll use
stdin_r, stdin_w := io.Pipe()
stdout_r, stdout_w := io.Pipe()
stderr_r, stderr_w := io.Pipe()
var cmd packer.RemoteCmd
cmd.Command = "foo"
cmd.Stdin = stdin_r
cmd.Stdout = stdout_w
cmd.Stderr = stderr_w
// Test Start // Test Start
remote := Communicator(client) err = remote.Start(&cmd)
rc, err := remote.Start("foo")
assert.Nil(err, "should not have an error") assert.Nil(err, "should not have an error")
// Test that we can read from stdout // Test that we can read from stdout
bufOut := bufio.NewReader(rc.Stdout) c.startCmd.Stdout.Write([]byte("outfoo\n"))
c.startOut.Write([]byte("outfoo\n")) bufOut := bufio.NewReader(stdout_r)
data, err := bufOut.ReadString('\n') data, err := bufOut.ReadString('\n')
assert.Nil(err, "should have no problem reading stdout") assert.Nil(err, "should have no problem reading stdout")
assert.Equal(data, "outfoo\n", "should be correct stdout") assert.Equal(data, "outfoo\n", "should be correct stdout")
// Test that we can read from stderr // Test that we can read from stderr
bufErr := bufio.NewReader(rc.Stderr) c.startCmd.Stderr.Write([]byte("errfoo\n"))
c.startErr.Write([]byte("errfoo\n")) bufErr := bufio.NewReader(stderr_r)
data, err = bufErr.ReadString('\n') data, err = bufErr.ReadString('\n')
assert.Nil(err, "should have no problem reading stdout") assert.Nil(err, "should have no problem reading stderr")
assert.Equal(data, "errfoo\n", "should be correct stdout") assert.Equal(data, "errfoo\n", "should be correct stderr")
// Test that we can write to stdin // Test that we can write to stdin
bufIn := bufio.NewReader(c.startIn) stdin_w.Write([]byte("infoo\n"))
rc.Stdin.Write([]byte("infoo\n")) bufIn := bufio.NewReader(c.startCmd.Stdin)
data, err = bufIn.ReadString('\n') data, err = bufIn.ReadString('\n')
assert.Nil(err, "should have no problem reading stdin") assert.Nil(err, "should have no problem reading stdin")
assert.Equal(data, "infoo\n", "should be correct stdin") assert.Equal(data, "infoo\n", "should be correct stdin")
// Test that we can get the exit status properly // Test that we can get the exit status properly
*c.startExitStatus = 42 // TODO
*c.startExited = true
rc.Wait()
assert.Equal(rc.ExitStatus, 42, "should have proper exit status")
// Test that we can upload things // Test that we can upload things
uploadR, uploadW := io.Pipe() uploadR, uploadW := io.Pipe()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment