Commit 2da59305 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

Merge pull request #701 from mitchellh/rpc-refactor

RPC happens over single TCP connection per plugin now

Three benefits:
* Single file descriptor per plugin
* NAT-friendly since plugins don't have to dial back in to the host
* Opens the foundation that we can easily use Unix domain sockets and such

A handful of Packer users were having issues with highly parallel (many builder/provisioner) templates where their systems would quickly reach their default file descriptor limits. This was because the previous mechanism would use a single TCP connection per RPC server, and Packer needs many (one per interface, basically). 

This merges in a MuxConn that multiplexes many "streams" on top of a single io.ReadWriteCloser. The RPC system has been revamped to know about this and use unique stream IDs to send everything over a single connection per plugin.

Previously, the RPC mechanism would sometimes send an address to the remote end and expect the remote end to connect back to it. While Packer shouldn't run remotely, some firewalls were having issues. This should be gone.

Finally, it should be possible now to optimize and use Unix domain sockets on Unix systems, avoiding ports and firewalls altogether.
parents 958f4191 8a24c9b1
......@@ -10,7 +10,6 @@ import (
"io/ioutil"
"log"
"net"
"net/rpc"
"os"
"os/exec"
"strings"
......@@ -130,56 +129,56 @@ func (c *Client) Exited() bool {
// Returns a builder implementation that is communicating over this
// client. If the client hasn't been started, this will start it.
func (c *Client) Builder() (packer.Builder, error) {
client, err := c.rpcClient()
client, err := c.packrpcClient()
if err != nil {
return nil, err
}
return &cmdBuilder{packrpc.Builder(client), c}, nil
return &cmdBuilder{client.Builder(), c}, nil
}
// Returns a command implementation that is communicating over this
// client. If the client hasn't been started, this will start it.
func (c *Client) Command() (packer.Command, error) {
client, err := c.rpcClient()
client, err := c.packrpcClient()
if err != nil {
return nil, err
}
return &cmdCommand{packrpc.Command(client), c}, nil
return &cmdCommand{client.Command(), c}, nil
}
// Returns a hook implementation that is communicating over this
// client. If the client hasn't been started, this will start it.
func (c *Client) Hook() (packer.Hook, error) {
client, err := c.rpcClient()
client, err := c.packrpcClient()
if err != nil {
return nil, err
}
return &cmdHook{packrpc.Hook(client), c}, nil
return &cmdHook{client.Hook(), c}, nil
}
// Returns a post-processor implementation that is communicating over
// this client. If the client hasn't been started, this will start it.
func (c *Client) PostProcessor() (packer.PostProcessor, error) {
client, err := c.rpcClient()
client, err := c.packrpcClient()
if err != nil {
return nil, err
}
return &cmdPostProcessor{packrpc.PostProcessor(client), c}, nil
return &cmdPostProcessor{client.PostProcessor(), c}, nil
}
// Returns a provisioner implementation that is communicating over this
// client. If the client hasn't been started, this will start it.
func (c *Client) Provisioner() (packer.Provisioner, error) {
client, err := c.rpcClient()
client, err := c.packrpcClient()
if err != nil {
return nil, err
}
return &cmdProvisioner{packrpc.Provisioner(client), c}, nil
return &cmdProvisioner{client.Provisioner(), c}, nil
}
// End the executing subprocess (if it is running) and perform any cleanup
......@@ -361,7 +360,7 @@ func (c *Client) logStderr(r io.Reader) {
close(c.doneLogging)
}
func (c *Client) rpcClient() (*rpc.Client, error) {
func (c *Client) packrpcClient() (*packrpc.Client, error) {
address, err := c.Start()
if err != nil {
return nil, err
......@@ -376,5 +375,11 @@ func (c *Client) rpcClient() (*rpc.Client, error) {
tcpConn := conn.(*net.TCPConn)
tcpConn.SetKeepAlive(true)
return rpc.NewClient(tcpConn), nil
client, err := packrpc.NewClient(tcpConn)
if err != nil {
tcpConn.Close()
return nil, err
}
return client, nil
}
......@@ -14,7 +14,6 @@ import (
packrpc "github.com/mitchellh/packer/packer/rpc"
"log"
"net"
"net/rpc"
"os"
"os/signal"
"runtime"
......@@ -35,13 +34,14 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69
// know how to speak it.
const APIVersion = "1"
// This serves a single RPC connection on the given RPC server on
// a random port.
func serve(server *rpc.Server) (err error) {
// Server waits for a connection to this plugin and returns a Packer
// RPC server that you can use to register components and serve them.
func Server() (*packrpc.Server, error) {
log.Printf("Plugin build against Packer '%s'", packer.GitCommit)
if os.Getenv(MagicCookieKey) != MagicCookieValue {
return errors.New("Please do not execute plugins directly. Packer will execute these for you.")
return nil, errors.New(
"Please do not execute plugins directly. Packer will execute these for you.")
}
// If there is no explicit number of Go threads to use, then set it
......@@ -51,12 +51,12 @@ func serve(server *rpc.Server) (err error) {
minPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MIN_PORT"), 10, 32)
if err != nil {
return
return nil, err
}
maxPort, err := strconv.ParseInt(os.Getenv("PACKER_PLUGIN_MAX_PORT"), 10, 32)
if err != nil {
return
return nil, err
}
log.Printf("Plugin minimum port: %d\n", minPort)
......@@ -77,7 +77,6 @@ func serve(server *rpc.Server) (err error) {
break
}
defer listener.Close()
// Output the address to stdout
......@@ -90,102 +89,22 @@ func serve(server *rpc.Server) (err error) {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %s\n", err.Error())
return
return nil, err
}
// Serve a single connection
log.Println("Serving a plugin connection...")
server.ServeConn(conn)
return
}
// Registers a signal handler to swallow and count interrupts so that the
// plugin isn't killed. The main host Packer process is responsible
// for killing the plugins when interrupted.
func countInterrupts() {
// Eat the interrupts
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
go func() {
var count int32 = 0
for {
<-ch
newCount := atomic.AddInt32(&Interrupts, 1)
newCount := atomic.AddInt32(&count, 1)
log.Printf("Received interrupt signal (count: %d). Ignoring.", newCount)
}
}()
}
// Serves a builder from a plugin.
func ServeBuilder(builder packer.Builder) {
log.Println("Preparing to serve a builder plugin...")
server := rpc.NewServer()
packrpc.RegisterBuilder(server, builder)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a command from a plugin.
func ServeCommand(command packer.Command) {
log.Println("Preparing to serve a command plugin...")
server := rpc.NewServer()
packrpc.RegisterCommand(server, command)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a hook from a plugin.
func ServeHook(hook packer.Hook) {
log.Println("Preparing to serve a hook plugin...")
server := rpc.NewServer()
packrpc.RegisterHook(server, hook)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a post-processor from a plugin.
func ServePostProcessor(p packer.PostProcessor) {
log.Println("Preparing to serve a post-processor plugin...")
server := rpc.NewServer()
packrpc.RegisterPostProcessor(server, p)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Serves a provisioner from a plugin.
func ServeProvisioner(p packer.Provisioner) {
log.Println("Preparing to serve a provisioner plugin...")
server := rpc.NewServer()
packrpc.RegisterProvisioner(server, p)
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
}
}
// Tests whether or not the plugin was interrupted or not.
func Interrupted() bool {
return atomic.LoadInt32(&Interrupts) > 0
// Serve a single connection
log.Println("Serving a plugin connection...")
return packrpc.NewServer(conn), nil
}
......@@ -54,20 +54,50 @@ func TestHelperProcess(*testing.T) {
fmt.Printf("%s1|:1234\n", APIVersion)
<-make(chan int)
case "builder":
ServeBuilder(new(packer.MockBuilder))
server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterBuilder(new(packer.MockBuilder))
server.Serve()
case "command":
ServeCommand(new(helperCommand))
server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterCommand(new(helperCommand))
server.Serve()
case "hook":
ServeHook(new(packer.MockHook))
server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterHook(new(packer.MockHook))
server.Serve()
case "invalid-rpc-address":
fmt.Println("lolinvalid")
case "mock":
fmt.Printf("%s|:1234\n", APIVersion)
<-make(chan int)
case "post-processor":
ServePostProcessor(new(helperPostProcessor))
server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterPostProcessor(new(helperPostProcessor))
server.Serve()
case "provisioner":
ServeProvisioner(new(packer.MockProvisioner))
server, err := Server()
if err != nil {
log.Printf("[ERR] %s", err)
os.Exit(1)
}
server.RegisterProvisioner(new(packer.MockProvisioner))
server.Serve()
case "start-timeout":
time.Sleep(1 * time.Minute)
os.Exit(1)
......
......@@ -9,6 +9,7 @@ import (
// available over an RPC connection.
type artifact struct {
client *rpc.Client
endpoint string
}
// ArtifactServer wraps a packer.Artifact implementation and makes it
......@@ -17,33 +18,29 @@ type ArtifactServer struct {
artifact packer.Artifact
}
func Artifact(client *rpc.Client) *artifact {
return &artifact{client}
}
func (a *artifact) BuilderId() (result string) {
a.client.Call("Artifact.BuilderId", new(interface{}), &result)
a.client.Call(a.endpoint+".BuilderId", new(interface{}), &result)
return
}
func (a *artifact) Files() (result []string) {
a.client.Call("Artifact.Files", new(interface{}), &result)
a.client.Call(a.endpoint+".Files", new(interface{}), &result)
return
}
func (a *artifact) Id() (result string) {
a.client.Call("Artifact.Id", new(interface{}), &result)
a.client.Call(a.endpoint+".Id", new(interface{}), &result)
return
}
func (a *artifact) String() (result string) {
a.client.Call("Artifact.String", new(interface{}), &result)
a.client.Call(a.endpoint+".String", new(interface{}), &result)
return
}
func (a *artifact) Destroy() error {
var result error
if err := a.client.Call("Artifact.Destroy", new(interface{}), &result); err != nil {
if err := a.client.Call(a.endpoint+".Destroy", new(interface{}), &result); err != nil {
return err
}
......
......@@ -2,48 +2,21 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
type testArtifact struct{}
func (testArtifact) BuilderId() string {
return "bid"
}
func (testArtifact) Files() []string {
return []string{"a", "b"}
}
func (testArtifact) Id() string {
return "id"
}
func (testArtifact) String() string {
return "string"
}
func (testArtifact) Destroy() error {
return nil
}
func TestArtifactRPC(t *testing.T) {
// Create the interface to test
a := new(testArtifact)
a := new(packer.MockArtifact)
// Start the server
server := rpc.NewServer()
RegisterArtifact(server, a)
address := serveSingleConn(server)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterArtifact(a)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
aClient := Artifact(client)
aClient := client.Artifact()
// Test
if aClient.BuilderId() != "bid" {
......@@ -64,5 +37,5 @@ func TestArtifactRPC(t *testing.T) {
}
func TestArtifact_Implements(t *testing.T) {
var _ packer.Artifact = Artifact(nil)
var _ packer.Artifact = new(artifact)
}
......@@ -9,16 +9,14 @@ import (
// over an RPC connection.
type build struct {
client *rpc.Client
mux *MuxConn
}
// BuildServer wraps a packer.Build implementation and makes it exportable
// as part of a Golang RPC server.
type BuildServer struct {
build packer.Build
}
type BuildRunArgs struct {
UiRPCAddress string
mux *MuxConn
}
type BuildPrepareResponse struct {
......@@ -26,10 +24,6 @@ type BuildPrepareResponse struct {
Error error
}
func Build(client *rpc.Client) *build {
return &build{client}
}
func (b *build) Name() (result string) {
b.client.Call("Build.Name", new(interface{}), &result)
return
......@@ -45,25 +39,25 @@ func (b *build) Prepare(v map[string]string) ([]string, error) {
}
func (b *build) Run(ui packer.Ui, cache packer.Cache) ([]packer.Artifact, error) {
// Create and start the server for the UI
server := rpc.NewServer()
RegisterCache(server, cache)
RegisterUi(server, ui)
args := &BuildRunArgs{serveSingleConn(server)}
var result []string
if err := b.client.Call("Build.Run", args, &result); err != nil {
nextId := b.mux.NextId()
server := NewServerWithMux(b.mux, nextId)
server.RegisterCache(cache)
server.RegisterUi(ui)
go server.Serve()
var result []uint32
if err := b.client.Call("Build.Run", nextId, &result); err != nil {
return nil, err
}
artifacts := make([]packer.Artifact, len(result))
for i, addr := range result {
client, err := rpcDial(addr)
for i, streamId := range result {
client, err := NewClientWithMux(b.mux, streamId)
if err != nil {
return nil, err
}
artifacts[i] = Artifact(client)
artifacts[i] = client.Artifact()
}
return artifacts, nil
......@@ -101,22 +95,26 @@ func (b *BuildServer) Prepare(v map[string]string, resp *BuildPrepareResponse) e
return nil
}
func (b *BuildServer) Run(args *BuildRunArgs, reply *[]string) error {
client, err := rpcDial(args.UiRPCAddress)
func (b *BuildServer) Run(streamId uint32, reply *[]uint32) error {
client, err := NewClientWithMux(b.mux, streamId)
if err != nil {
return err
return NewBasicError(err)
}
defer client.Close()
artifacts, err := b.build.Run(&Ui{client}, Cache(client))
artifacts, err := b.build.Run(client.Ui(), client.Cache())
if err != nil {
return NewBasicError(err)
}
*reply = make([]string, len(artifacts))
*reply = make([]uint32, len(artifacts))
for i, artifact := range artifacts {
server := rpc.NewServer()
RegisterArtifact(server, artifact)
(*reply)[i] = serveSingleConn(server)
streamId := b.mux.NextId()
server := NewServerWithMux(b.mux, streamId)
server.RegisterArtifact(artifact)
go server.Serve()
(*reply)[i] = streamId
}
return nil
......
......@@ -3,12 +3,11 @@ package rpc
import (
"errors"
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
var testBuildArtifact = &testArtifact{}
var testBuildArtifact = &packer.MockArtifact{}
type testBuild struct {
nameCalled bool
......@@ -60,25 +59,13 @@ func (b *testBuild) Cancel() {
b.cancelCalled = true
}
func buildRPCClient(t *testing.T) (*testBuild, packer.Build) {
// Create the interface to test
b := new(testBuild)
// Start the server
server := rpc.NewServer()
RegisterBuild(server, b)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
return b, Build(client)
}
func TestBuild(t *testing.T) {
b, bClient := buildRPCClient(t)
b := new(testBuild)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuild(b)
bClient := client.Build()
// Test Name
bClient.Name()
......@@ -120,23 +107,6 @@ func TestBuild(t *testing.T) {
t.Fatalf("bad: %#v", artifacts)
}
// Test the UI given to run, which should be fully functional
if b.runCalled {
b.runCache.Lock("foo")
if !cache.lockCalled {
t.Fatal("lock shuld be called")
}
b.runUi.Say("format")
if !ui.sayCalled {
t.Fatal("say should be called")
}
if ui.sayMessage != "format" {
t.Fatalf("bad: %#v", ui.sayMessage)
}
}
// Test run with an error
b.errRunResult = true
_, err = bClient.Run(ui, cache)
......@@ -164,7 +134,12 @@ func TestBuild(t *testing.T) {
}
func TestBuildPrepare_Warnings(t *testing.T) {
b, bClient := buildRPCClient(t)
b := new(testBuild)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuild(b)
bClient := client.Build()
expected := []string{"foo"}
b.prepareWarnings = expected
......@@ -179,5 +154,5 @@ func TestBuildPrepare_Warnings(t *testing.T) {
}
func TestBuild_ImplementsBuild(t *testing.T) {
var _ packer.Build = Build(nil)
var _ packer.Build = new(build)
}
package rpc
import (
"encoding/gob"
"fmt"
"github.com/mitchellh/packer/packer"
"log"
"net/rpc"
......@@ -12,37 +10,25 @@ import (
// over an RPC connection.
type builder struct {
client *rpc.Client
mux *MuxConn
}
// BuilderServer wraps a packer.Builder implementation and makes it exportable
// as part of a Golang RPC server.
type BuilderServer struct {
builder packer.Builder
mux *MuxConn
}
type BuilderPrepareArgs struct {
Configs []interface{}
}
type BuilderRunArgs struct {
RPCAddress string
ResponseAddress string
}
type BuilderPrepareResponse struct {
Warnings []string
Error error
}
type BuilderRunResponse struct {
Err error
RPCAddress string
}
func Builder(client *rpc.Client) *builder {
return &builder{client}
}
func (b *builder) Prepare(config ...interface{}) ([]string, error) {
var resp BuilderPrepareResponse
cerr := b.client.Call("Builder.Prepare", &BuilderPrepareArgs{config}, &resp)
......@@ -54,58 +40,28 @@ func (b *builder) Prepare(config ...interface{}) ([]string, error) {
}
func (b *builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) {
// Create and start the server for the Build and UI
server := rpc.NewServer()
RegisterCache(server, cache)
RegisterHook(server, hook)
RegisterUi(server, ui)
// Create a server for the response
responseL := netListenerInRange(portRangeMin, portRangeMax)
runResponseCh := make(chan *BuilderRunResponse)
go func() {
defer responseL.Close()
var response BuilderRunResponse
defer func() { runResponseCh <- &response }()
conn, err := responseL.Accept()
if err != nil {
response.Err = err
return
}
defer conn.Close()
decoder := gob.NewDecoder(conn)
if err := decoder.Decode(&response); err != nil {
response.Err = fmt.Errorf("Error waiting for Run: %s", err)
}
}()
args := &BuilderRunArgs{
serveSingleConn(server),
responseL.Addr().String(),
}
if err := b.client.Call("Builder.Run", args, new(interface{})); err != nil {
nextId := b.mux.NextId()
server := NewServerWithMux(b.mux, nextId)
server.RegisterCache(cache)
server.RegisterHook(hook)
server.RegisterUi(ui)
go server.Serve()
var responseId uint32
if err := b.client.Call("Builder.Run", nextId, &responseId); err != nil {
return nil, err
}
response := <-runResponseCh
if response.Err != nil {
return nil, response.Err
}
if response.RPCAddress == "" {
if responseId == 0 {
return nil, nil
}
client, err := rpcDial(response.RPCAddress)
client, err := NewClientWithMux(b.mux, responseId)
if err != nil {
return nil, err
}
return Artifact(client), nil
return client.Artifact(), nil
}
func (b *builder) Cancel() {
......@@ -127,45 +83,26 @@ func (b *BuilderServer) Prepare(args *BuilderPrepareArgs, reply *BuilderPrepareR
return nil
}
func (b *BuilderServer) Run(args *BuilderRunArgs, reply *interface{}) error {
client, err := rpcDial(args.RPCAddress)
func (b *BuilderServer) Run(streamId uint32, reply *uint32) error {
client, err := NewClientWithMux(b.mux, streamId)
if err != nil {
return err
return NewBasicError(err)
}
defer client.Close()
responseC, err := tcpDial(args.ResponseAddress)
artifact, err := b.builder.Run(client.Ui(), client.Hook(), client.Cache())
if err != nil {
return err
return NewBasicError(err)
}
responseWriter := gob.NewEncoder(responseC)
// Run the build in a goroutine so we don't block the RPC connection
go func() {
defer responseC.Close()
cache := Cache(client)
hook := Hook(client)
ui := &Ui{client}
artifact, responseErr := b.builder.Run(ui, hook, cache)
responseAddress := ""
if responseErr == nil && artifact != nil {
// Wrap the artifact
server := rpc.NewServer()
RegisterArtifact(server, artifact)
responseAddress = serveSingleConn(server)
}
if responseErr != nil {
responseErr = NewBasicError(responseErr)
}
err := responseWriter.Encode(&BuilderRunResponse{responseErr, responseAddress})
if err != nil {
log.Printf("BuildServer.Run error: %s", err)
*reply = 0
if artifact != nil {
streamId = b.mux.NextId()
server := NewServerWithMux(b.mux, streamId)
server.RegisterArtifact(artifact)
go server.Serve()
*reply = streamId
}
}()
return nil
}
......
......@@ -2,31 +2,19 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
var testBuilderArtifact = &testArtifact{}
func builderRPCClient(t *testing.T) (*packer.MockBuilder, packer.Builder) {
b := new(packer.MockBuilder)
// Start the server
server := rpc.NewServer()
RegisterBuilder(server, b)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
return b, Builder(client)
}
var testBuilderArtifact = &packer.MockArtifact{}
func TestBuilderPrepare(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
// Test Prepare
config := 42
......@@ -48,7 +36,12 @@ func TestBuilderPrepare(t *testing.T) {
}
func TestBuilderPrepare_Warnings(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
expected := []string{"foo"}
b.PrepareWarnings = expected
......@@ -64,7 +57,12 @@ func TestBuilderPrepare_Warnings(t *testing.T) {
}
func TestBuilderRun(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
// Test Run
cache := new(testCache)
......@@ -79,34 +77,21 @@ func TestBuilderRun(t *testing.T) {
t.Fatal("run should be called")
}
b.RunCache.Lock("foo")
if !cache.lockCalled {
t.Fatal("should be called")
}
b.RunHook.Run("foo", nil, nil, nil)
if !hook.RunCalled {
t.Fatal("should be called")
}
b.RunUi.Say("format")
if !ui.sayCalled {
t.Fatal("say should be called")
}
if ui.sayMessage != "format" {
t.Fatalf("bad: %s", ui.sayMessage)
}
if artifact.Id() != testBuilderArtifact.Id() {
t.Fatalf("bad: %s", artifact.Id())
}
}
func TestBuilderRun_nilResult(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
b.RunNilResult = true
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
cache := new(testCache)
hook := &packer.MockHook{}
ui := &testUi{}
......@@ -120,7 +105,13 @@ func TestBuilderRun_nilResult(t *testing.T) {
}
func TestBuilderRun_ErrResult(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
b.RunErrResult = true
cache := new(testCache)
......@@ -136,7 +127,12 @@ func TestBuilderRun_ErrResult(t *testing.T) {
}
func TestBuilderCancel(t *testing.T) {
b, bClient := builderRPCClient(t)
b := new(packer.MockBuilder)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterBuilder(b)
bClient := client.Builder()
bClient.Cancel()
if !b.CancelCalled {
......@@ -145,5 +141,5 @@ func TestBuilderCancel(t *testing.T) {
}
func TestBuilder_ImplementsBuilder(t *testing.T) {
var _ packer.Builder = Builder(nil)
var _ packer.Builder = new(builder)
}
......@@ -17,10 +17,6 @@ type CacheServer struct {
cache packer.Cache
}
func Cache(client *rpc.Client) *cache {
return &cache{client}
}
type CacheRLockResponse struct {
Path string
Exists bool
......
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"testing"
)
......@@ -40,11 +39,7 @@ func (t *testCache) RUnlock(key string) {
}
func TestCache_Implements(t *testing.T) {
var raw interface{}
raw = Cache(nil)
if _, ok := raw.(packer.Cache); !ok {
t.Fatal("Cache must be a cache.")
}
var _ packer.Cache = new(cache)
}
func TestCacheRPC(t *testing.T) {
......@@ -52,19 +47,15 @@ func TestCacheRPC(t *testing.T) {
c := new(testCache)
// Start the server
server := rpc.NewServer()
RegisterCache(server, c)
address := serveSingleConn(server)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterCache(c)
// Create the client over RPC and run some methods to verify it works
rpcClient, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("bad: %s", err)
}
client := Cache(rpcClient)
cacheClient := client.Cache()
// Test Lock
client.Lock("foo")
cacheClient.Lock("foo")
if !c.lockCalled {
t.Fatal("should be called")
}
......@@ -73,7 +64,7 @@ func TestCacheRPC(t *testing.T) {
}
// Test Unlock
client.Unlock("foo")
cacheClient.Unlock("foo")
if !c.unlockCalled {
t.Fatal("should be called")
}
......@@ -82,7 +73,7 @@ func TestCacheRPC(t *testing.T) {
}
// Test RLock
client.RLock("foo")
cacheClient.RLock("foo")
if !c.rlockCalled {
t.Fatal("should be called")
}
......@@ -91,7 +82,7 @@ func TestCacheRPC(t *testing.T) {
}
// Test RUnlock
client.RUnlock("foo")
cacheClient.RUnlock("foo")
if !c.runlockCalled {
t.Fatal("should be called")
}
......
package rpc
import (
"github.com/mitchellh/packer/packer"
"io"
"net/rpc"
)
// Client is the client end that communicates with a Packer RPC server.
// Establishing a connection is up to the user, the Client can just
// communicate over any ReadWriteCloser.
type Client struct {
mux *MuxConn
client *rpc.Client
}
func NewClient(rwc io.ReadWriteCloser) (*Client, error) {
return NewClientWithMux(NewMuxConn(rwc), 0)
}
func NewClientWithMux(mux *MuxConn, streamId uint32) (*Client, error) {
clientConn, err := mux.Dial(streamId)
if err != nil {
return nil, err
}
return &Client{
mux: mux,
client: rpc.NewClient(clientConn),
}, nil
}
func (c *Client) Close() error {
if err := c.client.Close(); err != nil {
return err
}
return nil
}
func (c *Client) Artifact() packer.Artifact {
return &artifact{
client: c.client,
endpoint: DefaultArtifactEndpoint,
}
}
func (c *Client) Build() packer.Build {
return &build{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Builder() packer.Builder {
return &builder{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Cache() packer.Cache {
return &cache{
client: c.client,
}
}
func (c *Client) Command() packer.Command {
return &command{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Communicator() packer.Communicator {
return &communicator{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Environment() packer.Environment {
return &Environment{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Hook() packer.Hook {
return &hook{
client: c.client,
mux: c.mux,
}
}
func (c *Client) PostProcessor() packer.PostProcessor {
return &postProcessor{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Provisioner() packer.Provisioner {
return &provisioner{
client: c.client,
mux: c.mux,
}
}
func (c *Client) Ui() packer.Ui {
return &Ui{
client: c.client,
endpoint: DefaultUiEndpoint,
}
}
package rpc
import (
"net"
"testing"
)
func testConn(t *testing.T) (net.Conn, net.Conn) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("err: %s", err)
}
var serverConn net.Conn
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
defer l.Close()
var err error
serverConn, err = l.Accept()
if err != nil {
t.Fatalf("err: %s", err)
}
}()
clientConn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("err: %s", err)
}
<-doneCh
return clientConn, serverConn
}
func testClientServer(t *testing.T) (*Client, *Server) {
clientConn, serverConn := testConn(t)
server := NewServer(serverConn)
go server.Serve()
client, err := NewClient(clientConn)
if err != nil {
server.Close()
t.Fatalf("err: %s", err)
}
return client, server
}
......@@ -9,25 +9,23 @@ import (
// command is actually executed over an RPC connection.
type command struct {
client *rpc.Client
mux *MuxConn
}
// A CommandServer wraps a packer.Command and makes it exportable as part
// of a Golang RPC server.
type CommandServer struct {
command packer.Command
mux *MuxConn
}
type CommandRunArgs struct {
RPCAddress string
Args []string
StreamId uint32
}
type CommandSynopsisArgs byte
func Command(client *rpc.Client) *command {
return &command{client}
}
func (c *command) Help() (result string) {
err := c.client.Call("Command.Help", new(interface{}), &result)
if err != nil {
......@@ -38,11 +36,15 @@ func (c *command) Help() (result string) {
}
func (c *command) Run(env packer.Environment, args []string) (result int) {
// Create and start the server for the Environment
server := rpc.NewServer()
RegisterEnvironment(server, env)
rpcArgs := &CommandRunArgs{serveSingleConn(server), args}
nextId := c.mux.NextId()
server := NewServerWithMux(c.mux, nextId)
server.RegisterEnvironment(env)
go server.Serve()
rpcArgs := &CommandRunArgs{
Args: args,
StreamId: nextId,
}
err := c.client.Call("Command.Run", rpcArgs, &result)
if err != nil {
panic(err)
......@@ -66,14 +68,13 @@ func (c *CommandServer) Help(args *interface{}, reply *string) error {
}
func (c *CommandServer) Run(args *CommandRunArgs, reply *int) error {
client, err := rpcDial(args.RPCAddress)
client, err := NewClientWithMux(c.mux, args.StreamId)
if err != nil {
return err
return NewBasicError(err)
}
defer client.Close()
env := &Environment{client}
*reply = c.command.Run(env, args.Args)
*reply = c.command.Run(client.Environment(), args.Args)
return nil
}
......
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
......@@ -33,21 +32,14 @@ func TestRPCCommand(t *testing.T) {
command := new(TestCommand)
// Start the server
server := rpc.NewServer()
RegisterCommand(server, command)
address := serveSingleConn(server)
// Create the command client over RPC and run some methods to verify
// we get the proper behavior.
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
clientComm := Command(client)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterCommand(command)
commClient := client.Command()
//Test Help
help := clientComm.Help()
help := commClient.Help()
if help != "bar" {
t.Fatalf("bad: %s", help)
}
......@@ -55,7 +47,7 @@ func TestRPCCommand(t *testing.T) {
// Test run
runArgs := []string{"foo", "bar"}
testEnv := &testEnvironment{}
exitCode := clientComm.Run(testEnv, runArgs)
exitCode := commClient.Run(testEnv, runArgs)
if !reflect.DeepEqual(command.runArgs, runArgs) {
t.Fatalf("bad: %#v", command.runArgs)
}
......@@ -67,18 +59,13 @@ func TestRPCCommand(t *testing.T) {
t.Fatal("runEnv should not be nil")
}
command.runEnv.Ui()
if !testEnv.uiCalled {
t.Fatal("ui should be called")
}
// Test Synopsis
synopsis := clientComm.Synopsis()
synopsis := commClient.Synopsis()
if synopsis != "foo" {
t.Fatalf("bad: %#v", synopsis)
}
}
func TestCommand_Implements(t *testing.T) {
var _ packer.Command = Command(nil)
var _ packer.Command = new(command)
}
......@@ -2,11 +2,9 @@ package rpc
import (
"encoding/gob"
"errors"
"github.com/mitchellh/packer/packer"
"io"
"log"
"net"
"net/rpc"
)
......@@ -14,12 +12,14 @@ import (
// executed over an RPC connection.
type communicator struct {
client *rpc.Client
mux *MuxConn
}
// CommunicatorServer wraps a packer.Communicator implementation and makes
// it exportable as part of a Golang RPC server.
type CommunicatorServer struct {
c packer.Communicator
mux *MuxConn
}
type CommandFinished struct {
......@@ -28,20 +28,20 @@ type CommandFinished struct {
type CommunicatorStartArgs struct {
Command string
StdinAddress string
StdoutAddress string
StderrAddress string
ResponseAddress string
StdinStreamId uint32
StdoutStreamId uint32
StderrStreamId uint32
ResponseStreamId uint32
}
type CommunicatorDownloadArgs struct {
Path string
WriterAddress string
WriterStreamId uint32
}
type CommunicatorUploadArgs struct {
Path string
ReaderAddress string
ReaderStreamId uint32
}
type CommunicatorUploadDirArgs struct {
......@@ -51,7 +51,7 @@ type CommunicatorUploadDirArgs struct {
}
func Communicator(client *rpc.Client) *communicator {
return &communicator{client}
return &communicator{client: client}
}
func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
......@@ -59,45 +59,43 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
args.Command = cmd.Command
if cmd.Stdin != nil {
stdinL := netListenerInRange(portRangeMin, portRangeMax)
args.StdinAddress = stdinL.Addr().String()
go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin)
args.StdinStreamId = c.mux.NextId()
go serveSingleCopy("stdin", c.mux, args.StdinStreamId, nil, cmd.Stdin)
}
if cmd.Stdout != nil {
stdoutL := netListenerInRange(portRangeMin, portRangeMax)
args.StdoutAddress = stdoutL.Addr().String()
go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil)
args.StdoutStreamId = c.mux.NextId()
go serveSingleCopy("stdout", c.mux, args.StdoutStreamId, cmd.Stdout, nil)
}
if cmd.Stderr != nil {
stderrL := netListenerInRange(portRangeMin, portRangeMax)
args.StderrAddress = stderrL.Addr().String()
go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil)
args.StderrStreamId = c.mux.NextId()
go serveSingleCopy("stderr", c.mux, args.StderrStreamId, cmd.Stderr, nil)
}
responseL := netListenerInRange(portRangeMin, portRangeMax)
args.ResponseAddress = responseL.Addr().String()
responseStreamId := c.mux.NextId()
args.ResponseStreamId = responseStreamId
go func() {
defer responseL.Close()
conn, err := responseL.Accept()
conn, err := c.mux.Accept(responseStreamId)
if err != nil {
log.Printf("[ERR] Error accepting response stream %d: %s",
responseStreamId, err)
cmd.SetExited(123)
return
}
defer conn.Close()
decoder := gob.NewDecoder(conn)
var finished CommandFinished
decoder := gob.NewDecoder(conn)
if err := decoder.Decode(&finished); err != nil {
log.Printf("[ERR] Error decoding response stream %d: %s",
responseStreamId, err)
cmd.SetExited(123)
return
}
log.Printf("[INFO] RPC client: Communicator ended with: %d", finished.ExitStatus)
cmd.SetExited(finished.ExitStatus)
}()
......@@ -106,23 +104,13 @@ func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
}
func (c *communicator) Upload(path string, r io.Reader) (err error) {
// We need to create a server that can proxy the reader data
// over because we can't simply gob encode an io.Reader
readerL := netListenerInRange(portRangeMin, portRangeMax)
if readerL == nil {
err = errors.New("couldn't allocate listener for upload reader")
return
}
// Make sure at the end of this call, we close the listener
defer readerL.Close()
// Pipe the reader through to the connection
go serveSingleCopy("uploadReader", readerL, nil, r)
streamId := c.mux.NextId()
go serveSingleCopy("uploadData", c.mux, streamId, nil, r)
args := CommunicatorUploadArgs{
path,
readerL.Addr().String(),
Path: path,
ReaderStreamId: streamId,
}
err = c.client.Call("Communicator.Upload", &args, new(interface{}))
......@@ -146,99 +134,104 @@ func (c *communicator) UploadDir(dst string, src string, exclude []string) error
}
func (c *communicator) Download(path string, w io.Writer) (err error) {
// We need to create a server that can proxy that data downloaded
// into the writer because we can't gob encode a writer directly.
writerL := netListenerInRange(portRangeMin, portRangeMax)
if writerL == nil {
err = errors.New("couldn't allocate listener for download writer")
return
}
// Make sure we close the listener once we're done because we'll be done
defer writerL.Close()
// Serve a single connection and a single copy
go serveSingleCopy("downloadWriter", writerL, w, nil)
streamId := c.mux.NextId()
go serveSingleCopy("downloadWriter", c.mux, streamId, w, nil)
args := CommunicatorDownloadArgs{
path,
writerL.Addr().String(),
Path: path,
WriterStreamId: streamId,
}
err = c.client.Call("Communicator.Download", &args, new(interface{}))
return
}
func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) {
func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (error) {
// Build the RemoteCmd on this side so that it all pipes over
// to the remote side.
var cmd packer.RemoteCmd
cmd.Command = args.Command
toClose := make([]net.Conn, 0)
if args.StdinAddress != "" {
stdinC, err := tcpDial(args.StdinAddress)
// Create a channel to signal we're done so that we can close
// our stdin/stdout/stderr streams
toClose := make([]io.Closer, 0)
doneCh := make(chan struct{})
go func() {
<-doneCh
for _, conn := range toClose {
defer conn.Close()
}
}()
if args.StdinStreamId > 0 {
conn, err := c.mux.Dial(args.StdinStreamId)
if err != nil {
return err
close(doneCh)
return NewBasicError(err)
}
toClose = append(toClose, stdinC)
cmd.Stdin = stdinC
toClose = append(toClose, conn)
cmd.Stdin = conn
}
if args.StdoutAddress != "" {
stdoutC, err := tcpDial(args.StdoutAddress)
if args.StdoutStreamId > 0 {
conn, err := c.mux.Dial(args.StdoutStreamId)
if err != nil {
return err
close(doneCh)
return NewBasicError(err)
}
toClose = append(toClose, stdoutC)
cmd.Stdout = stdoutC
toClose = append(toClose, conn)
cmd.Stdout = conn
}
if args.StderrAddress != "" {
stderrC, err := tcpDial(args.StderrAddress)
if args.StderrStreamId > 0 {
conn, err := c.mux.Dial(args.StderrStreamId)
if err != nil {
return err
close(doneCh)
return NewBasicError(err)
}
toClose = append(toClose, stderrC)
cmd.Stderr = stderrC
toClose = append(toClose, conn)
cmd.Stderr = conn
}
// Connect to the response address so we can write our result to it
// when ready.
responseC, err := tcpDial(args.ResponseAddress)
responseC, err := c.mux.Dial(args.ResponseStreamId)
if err != nil {
return err
close(doneCh)
return NewBasicError(err)
}
responseWriter := gob.NewEncoder(responseC)
// Start the actual command
err = c.c.Start(&cmd)
if err != nil {
close(doneCh)
return NewBasicError(err)
}
// Start a goroutine to spin and wait for the process to actual
// exit. When it does, report it back to caller...
go func() {
defer close(doneCh)
defer responseC.Close()
for _, conn := range toClose {
defer conn.Close()
}
cmd.Wait()
log.Printf("[INFO] RPC endpoint: Communicator ended with: %d", cmd.ExitStatus)
responseWriter.Encode(&CommandFinished{cmd.ExitStatus})
}()
return
return nil
}
func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) {
readerC, err := tcpDial(args.ReaderAddress)
readerC, err := c.mux.Dial(args.ReaderStreamId)
if err != nil {
return
}
defer readerC.Close()
err = c.c.Upload(args.Path, readerC)
......@@ -250,21 +243,18 @@ func (c *CommunicatorServer) UploadDir(args *CommunicatorUploadDirArgs, reply *e
}
func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) {
writerC, err := tcpDial(args.WriterAddress)
writerC, err := c.mux.Dial(args.WriterStreamId)
if err != nil {
return
}
defer writerC.Close()
err = c.c.Download(args.Path, writerC)
return
}
func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) {
defer l.Close()
conn, err := l.Accept()
func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) {
conn, err := mux.Accept(id)
if err != nil {
log.Printf("'%s' accept error: %s", name, err)
return
......
......@@ -4,7 +4,6 @@ import (
"bufio"
"github.com/mitchellh/packer/packer"
"io"
"net/rpc"
"reflect"
"testing"
)
......@@ -14,16 +13,11 @@ func TestCommunicatorRPC(t *testing.T) {
c := new(packer.MockCommunicator)
// Start the server
server := rpc.NewServer()
RegisterCommunicator(server, c)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
remote := Communicator(client)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterCommunicator(c)
remote := client.Communicator()
// The remote command we'll use
stdin_r, stdin_w := io.Pipe()
......@@ -42,7 +36,7 @@ func TestCommunicatorRPC(t *testing.T) {
c.StartExitStatus = 42
// Test Start
err = remote.Start(&cmd)
err := remote.Start(&cmd)
if err != nil {
t.Fatalf("err: %s", err)
}
......@@ -74,7 +68,7 @@ func TestCommunicatorRPC(t *testing.T) {
stdin_w.Close()
cmd.Wait()
if c.StartStdin != "info\n" {
t.Fatalf("bad data: %s", data)
t.Fatalf("bad data: %s", c.StartStdin)
}
// Test that we can get the exit status properly
......
......@@ -2,6 +2,7 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"log"
"net/rpc"
)
......@@ -9,12 +10,14 @@ import (
// where the actual environment is executed over an RPC connection.
type Environment struct {
client *rpc.Client
mux *MuxConn
}
// A EnvironmentServer wraps a packer.Environment and makes it exportable
// as part of a Golang RPC server.
type EnvironmentServer struct {
env packer.Environment
mux *MuxConn
}
type EnvironmentCliArgs struct {
......@@ -22,33 +25,32 @@ type EnvironmentCliArgs struct {
}
func (e *Environment) Builder(name string) (b packer.Builder, err error) {
var reply string
err = e.client.Call("Environment.Builder", name, &reply)
var streamId uint32
err = e.client.Call("Environment.Builder", name, &streamId)
if err != nil {
return
}
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
return
return nil, err
}
b = Builder(client)
b = client.Builder()
return
}
func (e *Environment) Cache() packer.Cache {
var reply string
if err := e.client.Call("Environment.Cache", new(interface{}), &reply); err != nil {
var streamId uint32
if err := e.client.Call("Environment.Cache", new(interface{}), &streamId); err != nil {
panic(err)
}
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
panic(err)
log.Printf("[ERR] Error getting cache client: %s", err)
return nil
}
return Cache(client)
return client.Cache()
}
func (e *Environment) Cli(args []string) (result int, err error) {
......@@ -58,85 +60,81 @@ func (e *Environment) Cli(args []string) (result int, err error) {
}
func (e *Environment) Hook(name string) (h packer.Hook, err error) {
var reply string
err = e.client.Call("Environment.Hook", name, &reply)
var streamId uint32
err = e.client.Call("Environment.Hook", name, &streamId)
if err != nil {
return
}
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
return
return nil, err
}
h = Hook(client)
return
return client.Hook(), nil
}
func (e *Environment) PostProcessor(name string) (p packer.PostProcessor, err error) {
var reply string
err = e.client.Call("Environment.PostProcessor", name, &reply)
var streamId uint32
err = e.client.Call("Environment.PostProcessor", name, &streamId)
if err != nil {
return
}
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
return
return nil, err
}
p = PostProcessor(client)
p = client.PostProcessor()
return
}
func (e *Environment) Provisioner(name string) (p packer.Provisioner, err error) {
var reply string
err = e.client.Call("Environment.Provisioner", name, &reply)
var streamId uint32
err = e.client.Call("Environment.Provisioner", name, &streamId)
if err != nil {
return
}
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
return
return nil, err
}
p = Provisioner(client)
p = client.Provisioner()
return
}
func (e *Environment) Ui() packer.Ui {
var reply string
e.client.Call("Environment.Ui", new(interface{}), &reply)
var streamId uint32
e.client.Call("Environment.Ui", new(interface{}), &streamId)
client, err := rpcDial(reply)
client, err := NewClientWithMux(e.mux, streamId)
if err != nil {
panic(err)
log.Printf("[ERR] Error connecting to Ui: %s", err)
return nil
}
return &Ui{client}
return client.Ui()
}
func (e *EnvironmentServer) Builder(name *string, reply *string) error {
builder, err := e.env.Builder(*name)
func (e *EnvironmentServer) Builder(name string, reply *uint32) error {
builder, err := e.env.Builder(name)
if err != nil {
return err
return NewBasicError(err)
}
// Wrap it
server := rpc.NewServer()
RegisterBuilder(server, builder)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterBuilder(builder)
go server.Serve()
return nil
}
func (e *EnvironmentServer) Cache(args *interface{}, reply *string) error {
func (e *EnvironmentServer) Cache(args *interface{}, reply *uint32) error {
cache := e.env.Cache()
server := rpc.NewServer()
RegisterCache(server, cache)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterCache(cache)
go server.Serve()
return nil
}
......@@ -145,53 +143,51 @@ func (e *EnvironmentServer) Cli(args *EnvironmentCliArgs, reply *int) (err error
return
}
func (e *EnvironmentServer) Hook(name *string, reply *string) error {
hook, err := e.env.Hook(*name)
func (e *EnvironmentServer) Hook(name string, reply *uint32) error {
hook, err := e.env.Hook(name)
if err != nil {
return err
return NewBasicError(err)
}
// Wrap it
server := rpc.NewServer()
RegisterHook(server, hook)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterHook(hook)
go server.Serve()
return nil
}
func (e *EnvironmentServer) PostProcessor(name *string, reply *string) error {
pp, err := e.env.PostProcessor(*name)
func (e *EnvironmentServer) PostProcessor(name string, reply *uint32) error {
pp, err := e.env.PostProcessor(name)
if err != nil {
return err
return NewBasicError(err)
}
server := rpc.NewServer()
RegisterPostProcessor(server, pp)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterPostProcessor(pp)
go server.Serve()
return nil
}
func (e *EnvironmentServer) Provisioner(name *string, reply *string) error {
prov, err := e.env.Provisioner(*name)
func (e *EnvironmentServer) Provisioner(name string, reply *uint32) error {
prov, err := e.env.Provisioner(name)
if err != nil {
return err
return NewBasicError(err)
}
server := rpc.NewServer()
RegisterProvisioner(server, prov)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterProvisioner(prov)
go server.Serve()
return nil
}
func (e *EnvironmentServer) Ui(args *interface{}, reply *string) error {
func (e *EnvironmentServer) Ui(args *interface{}, reply *uint32) error {
ui := e.env.Ui()
// Wrap it
server := rpc.NewServer()
RegisterUi(server, ui)
*reply = serveSingleConn(server)
*reply = e.mux.NextId()
server := NewServerWithMux(e.mux, *reply)
server.RegisterUi(ui)
go server.Serve()
return nil
}
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
......@@ -69,16 +68,11 @@ func TestEnvironmentRPC(t *testing.T) {
e := &testEnvironment{}
// Start the server
server := rpc.NewServer()
RegisterEnvironment(server, e)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
eClient := &Environment{client}
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterEnvironment(e)
eClient := client.Environment()
// Test Builder
builder, _ := eClient.Builder("foo")
......
......@@ -10,32 +10,36 @@ import (
// over an RPC connection.
type hook struct {
client *rpc.Client
mux *MuxConn
}
// HookServer wraps a packer.Hook implementation and makes it exportable
// as part of a Golang RPC server.
type HookServer struct {
hook packer.Hook
mux *MuxConn
}
type HookRunArgs struct {
Name string
Data interface{}
RPCAddress string
}
func Hook(client *rpc.Client) *hook {
return &hook{client}
StreamId uint32
}
func (h *hook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error {
server := rpc.NewServer()
RegisterCommunicator(server, comm)
RegisterUi(server, ui)
address := serveSingleConn(server)
nextId := h.mux.NextId()
server := NewServerWithMux(h.mux, nextId)
server.RegisterCommunicator(comm)
server.RegisterUi(ui)
go server.Serve()
args := &HookRunArgs{name, data, address}
return h.client.Call("Hook.Run", args, new(interface{}))
args := HookRunArgs{
Name: name,
Data: data,
StreamId: nextId,
}
return h.client.Call("Hook.Run", &args, new(interface{}))
}
func (h *hook) Cancel() {
......@@ -46,12 +50,13 @@ func (h *hook) Cancel() {
}
func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error {
client, err := rpcDial(args.RPCAddress)
client, err := NewClientWithMux(h.mux, args.StreamId)
if err != nil {
return err
return NewBasicError(err)
}
defer client.Close()
if err := h.hook.Run(args.Name, &Ui{client}, Communicator(client), args.Data); err != nil {
if err := h.hook.Run(args.Name, client.Ui(), client.Communicator(), args.Data); err != nil {
return NewBasicError(err)
}
......
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"sync"
"testing"
......@@ -14,17 +13,11 @@ func TestHookRPC(t *testing.T) {
h := new(packer.MockHook)
// Serve
server := rpc.NewServer()
RegisterHook(server, h)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
hClient := Hook(client)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterHook(h)
hClient := client.Hook()
// Test Run
ui := &testUi{}
......@@ -60,17 +53,11 @@ func TestHook_cancelWhileRun(t *testing.T) {
}
// Serve
server := rpc.NewServer()
RegisterHook(server, h)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
hClient := Hook(client)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterHook(h)
hClient := client.Hook()
// Start the run
finished := make(chan struct{})
......
package rpc
import (
"encoding/binary"
"fmt"
"io"
"log"
"sync"
"time"
)
// MuxConn is a connection that can be used bi-directionally for RPC. Normally,
// Go RPC only allows client-to-server connections. This allows the client
// to actually act as a server as well.
//
// MuxConn works using a fairly dumb multiplexing technique of simply
// framing every piece of data sent into a prefix + data format. Streams
// are established using a subset of the TCP protocol. Only a subset is
// necessary since we assume ordering on the underlying RWC.
type MuxConn struct {
curId uint32
rwc io.ReadWriteCloser
streams map[uint32]*Stream
mu sync.RWMutex
wlock sync.Mutex
}
type muxPacketType byte
const (
muxPacketSyn muxPacketType = iota
muxPacketAck
muxPacketFin
muxPacketData
)
func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
m := &MuxConn{
rwc: rwc,
streams: make(map[uint32]*Stream),
}
go m.loop()
return m
}
// Close closes the underlying io.ReadWriteCloser. This will also close
// all streams that are open.
func (m *MuxConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
// Close all the streams
for _, w := range m.streams {
w.Close()
}
m.streams = make(map[uint32]*Stream)
return m.rwc.Close()
}
// Accept accepts a multiplexed connection with the given ID. This
// will block until a request is made to connect.
func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
}
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
}
if stream.state == streamStateSynRecv {
// Fast track establishing since we already got the syn
stream.setState(streamStateEstablished)
stream.mu.Unlock()
}
if stream.state != streamStateEstablished {
// Go into the listening state
stream.setState(streamStateListen)
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
stream.registerStateListener(stateCh)
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
// Wait for the connection to establish
ACCEPT_ESTABLISH_LOOP:
for {
state := <-stateCh
switch state {
case streamStateListen:
case streamStateEstablished:
break ACCEPT_ESTABLISH_LOOP
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
}
}
// Send the ack down
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
return nil, err
}
return stream, nil
}
// Dial opens a connection to the remote end using the given stream ID.
// An Accept on the remote end will only work with if the IDs match.
func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
stream, err := m.openStream(id)
if err != nil {
return nil, err
}
// If the stream isn't closed, then it is already open somehow
stream.mu.Lock()
if stream.state != streamStateClosed {
stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
}
// Open a connection
if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil {
return nil, err
}
stream.setState(streamStateSynSent)
// Register a state change listener to wait for changes
stateCh := make(chan streamState, 10)
stream.registerStateListener(stateCh)
defer func() {
stream.mu.Lock()
defer stream.mu.Unlock()
stream.deregisterStateListener(stateCh)
}()
stream.mu.Unlock()
for {
state := <-stateCh
switch state {
case streamStateSynSent:
case streamStateEstablished:
return stream, nil
default:
defer stream.mu.Unlock()
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
}
}
}
// NextId returns the next available stream ID that isn't currently
// taken.
func (m *MuxConn) NextId() uint32 {
m.mu.Lock()
defer m.mu.Unlock()
for {
result := m.curId
m.curId++
if _, ok := m.streams[result]; !ok {
return result
}
}
}
func (m *MuxConn) openStream(id uint32) (*Stream, error) {
// First grab a read-lock if we have the stream already we can
// cheaply return it.
m.mu.RLock()
if stream, ok := m.streams[id]; ok {
m.mu.RUnlock()
return stream, nil
}
// Now acquire a full blown write lock so we can create the stream
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
// We have to check this again because there is a time period
// above where we couldn't lost this lock.
if stream, ok := m.streams[id]; ok {
return stream, nil
}
// Create the stream object and channel where data will be sent to
dataR, dataW := io.Pipe()
writeCh := make(chan []byte, 256)
// Set the data channel so we can write to it.
stream := &Stream{
id: id,
mux: m,
reader: dataR,
writeCh: writeCh,
stateChange: make(map[chan<- streamState]struct{}),
}
stream.setState(streamStateClosed)
// Start the goroutine that will read from the queue and write
// data out.
go func() {
defer dataW.Close()
for {
data := <-writeCh
if data == nil {
// A nil is a tombstone letting us know we're done
// accepting data.
return
}
if _, err := dataW.Write(data); err != nil {
return
}
}
}()
m.streams[id] = stream
return m.streams[id], nil
}
func (m *MuxConn) loop() {
defer func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, w := range m.streams {
w.mu.Lock()
w.remoteClose()
w.mu.Unlock()
}
}()
var id uint32
var packetType muxPacketType
var length int32
for {
if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
log.Printf("[ERR] Error reading stream ID: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
log.Printf("[ERR] Error reading packet type: %s", err)
return
}
if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
log.Printf("[ERR] Error reading length: %s", err)
return
}
// TODO(mitchellh): probably would be better to re-use a buffer...
data := make([]byte, length)
if length > 0 {
if _, err := m.rwc.Read(data); err != nil {
log.Printf("[ERR] Error reading data: %s", err)
return
}
}
stream, err := m.openStream(id)
if err != nil {
log.Printf("[ERR] Error opening stream %d: %s", id, err)
return
}
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
switch packetType {
case muxPacketAck:
stream.mu.Lock()
switch stream.state {
case streamStateSynSent:
stream.setState(streamStateEstablished)
case streamStateFinWait1:
stream.setState(streamStateFinWait2)
default:
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketSyn:
stream.mu.Lock()
switch stream.state {
case streamStateClosed:
stream.setState(streamStateSynRecv)
case streamStateListen:
stream.setState(streamStateEstablished)
default:
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
case muxPacketFin:
stream.mu.Lock()
switch stream.state {
case streamStateEstablished:
stream.setState(streamStateCloseWait)
m.write(id, muxPacketAck, nil)
// Close the writer on our end since we won't receive any
// more data.
stream.writeCh <- nil
case streamStateFinWait1:
fallthrough
case streamStateFinWait2:
stream.remoteClose()
// Remove this stream from being active so that it
// can be re-used
m.mu.Lock()
delete(m.streams, stream.id)
m.mu.Unlock()
default:
log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
}
stream.mu.Unlock()
case muxPacketData:
stream.mu.Lock()
if stream.state == streamStateEstablished {
select {
case stream.writeCh <- data:
default:
panic(fmt.Sprintf("Failed to write data, buffer full for stream %d", id))
}
} else {
log.Printf("[ERR] Data received for stream in state: %d", stream.state)
}
stream.mu.Unlock()
}
}
}
func (m *MuxConn) write(id uint32, dataType muxPacketType, p []byte) (int, error) {
m.wlock.Lock()
defer m.wlock.Unlock()
if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
return 0, err
}
if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
return 0, err
}
if len(p) == 0 {
return 0, nil
}
return m.rwc.Write(p)
}
// Stream is a single stream of data and implements io.ReadWriteCloser
type Stream struct {
id uint32
mux *MuxConn
reader io.Reader
state streamState
stateChange map[chan<- streamState]struct{}
stateUpdated time.Time
mu sync.Mutex
writeCh chan<- []byte
}
type streamState byte
const (
streamStateClosed streamState = iota
streamStateListen
streamStateSynRecv
streamStateSynSent
streamStateEstablished
streamStateFinWait1
streamStateFinWait2
streamStateCloseWait
)
func (s *Stream) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.state != streamStateEstablished && s.state != streamStateCloseWait {
return fmt.Errorf("Stream in bad state: %d", s.state)
}
if s.state == streamStateEstablished {
s.setState(streamStateFinWait1)
} else {
s.remoteClose()
}
s.mux.write(s.id, muxPacketFin, nil)
return nil
}
func (s *Stream) Read(p []byte) (int, error) {
return s.reader.Read(p)
}
func (s *Stream) Write(p []byte) (int, error) {
s.mu.Lock()
state := s.state
s.mu.Unlock()
if state != streamStateEstablished {
return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
}
return s.mux.write(s.id, muxPacketData, p)
}
func (s *Stream) remoteClose() {
s.setState(streamStateClosed)
s.writeCh <- nil
}
func (s *Stream) registerStateListener(ch chan<- streamState) {
s.stateChange[ch] = struct{}{}
}
func (s *Stream) deregisterStateListener(ch chan<- streamState) {
delete(s.stateChange, ch)
}
func (s *Stream) setState(state streamState) {
s.state = state
s.stateUpdated = time.Now().UTC()
for ch, _ := range s.stateChange {
select {
case ch <- state:
default:
}
}
}
package rpc
import (
"io"
"net"
"sync"
"testing"
)
func readStream(t *testing.T, s io.Reader) string {
var data [1024]byte
n, err := s.Read(data[:])
if err != nil {
t.Fatalf("err: %s", err)
}
return string(data[0:n])
}
func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("err: %s", err)
}
// Server side
doneCh := make(chan struct{})
go func() {
defer close(doneCh)
conn, err := l.Accept()
l.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
server = NewMuxConn(conn)
}()
// Client side
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("err: %s", err)
}
client = NewMuxConn(conn)
// Wait for the server
<-doneCh
return
}
func TestMuxConn(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
// When the server is done
doneCh := make(chan struct{})
// The server side
go func() {
defer close(doneCh)
s0, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := server.Dial(1)
if err != nil {
t.Fatalf("err: %s", err)
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
data := readStream(t, s1)
if data != "another" {
t.Fatalf("bad: %#v", data)
}
}()
go func() {
defer wg.Done()
data := readStream(t, s0)
if data != "hello" {
t.Fatalf("bad: %#v", data)
}
}()
wg.Wait()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
s1, err := client.Accept(1)
if err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s0.Write([]byte("hello")); err != nil {
t.Fatalf("err: %s", err)
}
if _, err := s1.Write([]byte("another")); err != nil {
t.Fatalf("err: %s", err)
}
// Wait for the server to be done
<-doneCh
}
func TestMuxConn_socketClose(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go func() {
_, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
server.rwc.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConn_clientClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go func() {
conn, err := server.Accept(0)
if err != nil {
t.Fatalf("err: %s", err)
}
conn.Close()
}()
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConn_serverClosesStreams(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
go server.Accept(0)
s0, err := client.Dial(0)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := server.Close(); err != nil {
t.Fatalf("err: %s", err)
}
// This should block forever since we never write onto this stream.
var data [1024]byte
_, err = s0.Read(data[:])
if err != io.EOF {
t.Fatalf("err: %s", err)
}
}
func TestMuxConnNextId(t *testing.T) {
client, server := testMux(t)
defer client.Close()
defer server.Close()
a := client.NextId()
b := client.NextId()
if a != 0 || b != 1 {
t.Fatalf("IDs should increment")
}
}
......@@ -9,11 +9,14 @@ import (
// executed over an RPC connection.
type postProcessor struct {
client *rpc.Client
mux *MuxConn
}
// PostProcessorServer wraps a packer.PostProcessor implementation and makes it
// exportable as part of a Golang RPC server.
type PostProcessorServer struct {
client *rpc.Client
mux *MuxConn
p packer.PostProcessor
}
......@@ -24,12 +27,9 @@ type PostProcessorConfigureArgs struct {
type PostProcessorProcessResponse struct {
Err error
Keep bool
RPCAddress string
StreamId uint32
}
func PostProcessor(client *rpc.Client) *postProcessor {
return &postProcessor{client}
}
func (p *postProcessor) Configure(raw ...interface{}) (err error) {
args := &PostProcessorConfigureArgs{Configs: raw}
if cerr := p.client.Call("PostProcessor.Configure", args, &err); cerr != nil {
......@@ -40,12 +40,14 @@ func (p *postProcessor) Configure(raw ...interface{}) (err error) {
}
func (p *postProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Artifact, bool, error) {
server := rpc.NewServer()
RegisterArtifact(server, a)
RegisterUi(server, ui)
nextId := p.mux.NextId()
server := NewServerWithMux(p.mux, nextId)
server.RegisterArtifact(a)
server.RegisterUi(ui)
go server.Serve()
var response PostProcessorProcessResponse
if err := p.client.Call("PostProcessor.PostProcess", serveSingleConn(server), &response); err != nil {
if err := p.client.Call("PostProcessor.PostProcess", nextId, &response); err != nil {
return nil, false, err
}
......@@ -53,16 +55,16 @@ func (p *postProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Art
return nil, false, response.Err
}
if response.RPCAddress == "" {
if response.StreamId == 0 {
return nil, false, nil
}
client, err := rpcDial(response.RPCAddress)
client, err := NewClientWithMux(p.mux, response.StreamId)
if err != nil {
return nil, false, err
}
return Artifact(client), response.Keep, nil
return client.Artifact(), response.Keep, nil
}
func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply *error) error {
......@@ -74,19 +76,20 @@ func (p *PostProcessorServer) Configure(args *PostProcessorConfigureArgs, reply
return nil
}
func (p *PostProcessorServer) PostProcess(address string, reply *PostProcessorProcessResponse) error {
client, err := rpcDial(address)
func (p *PostProcessorServer) PostProcess(streamId uint32, reply *PostProcessorProcessResponse) error {
client, err := NewClientWithMux(p.mux, streamId)
if err != nil {
return err
return NewBasicError(err)
}
responseAddress := ""
artifact, keep, err := p.p.PostProcess(&Ui{client}, Artifact(client))
if err == nil && artifact != nil {
server := rpc.NewServer()
RegisterArtifact(server, artifact)
responseAddress = serveSingleConn(server)
defer client.Close()
streamId = 0
artifactResult, keep, err := p.p.PostProcess(client.Ui(), client.Artifact())
if err == nil && artifactResult != nil {
streamId = p.mux.NextId()
server := NewServerWithMux(p.mux, streamId)
server.RegisterArtifact(artifactResult)
go server.Serve()
}
if err != nil {
......@@ -96,7 +99,7 @@ func (p *PostProcessorServer) PostProcess(address string, reply *PostProcessorPr
*reply = PostProcessorProcessResponse{
Err: err,
Keep: keep,
RPCAddress: responseAddress,
StreamId: streamId,
}
return nil
......
......@@ -2,18 +2,18 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
var testPostProcessorArtifact = new(testArtifact)
var testPostProcessorArtifact = new(packer.MockArtifact)
type TestPostProcessor struct {
configCalled bool
configVal []interface{}
ppCalled bool
ppArtifact packer.Artifact
ppArtifactId string
ppUi packer.Ui
}
......@@ -26,6 +26,7 @@ func (pp *TestPostProcessor) Configure(v ...interface{}) error {
func (pp *TestPostProcessor) PostProcess(ui packer.Ui, a packer.Artifact) (packer.Artifact, bool, error) {
pp.ppCalled = true
pp.ppArtifact = a
pp.ppArtifactId = a.Id()
pp.ppUi = ui
return testPostProcessorArtifact, false, nil
}
......@@ -35,20 +36,16 @@ func TestPostProcessorRPC(t *testing.T) {
p := new(TestPostProcessor)
// Start the server
server := rpc.NewServer()
RegisterPostProcessor(server, p)
address := serveSingleConn(server)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterPostProcessor(p)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("Error connecting to rpc: %s", err)
}
ppClient := client.PostProcessor()
// Test Configure
config := 42
pClient := PostProcessor(client)
err = pClient.Configure(config)
err := ppClient.Configure(config)
if err != nil {
t.Fatalf("error: %s", err)
}
......@@ -62,9 +59,11 @@ func TestPostProcessorRPC(t *testing.T) {
}
// Test PostProcess
a := new(testArtifact)
a := &packer.MockArtifact{
IdValue: "ppTestId",
}
ui := new(testUi)
artifact, _, err := pClient.PostProcess(ui, a)
artifact, _, err := ppClient.PostProcess(ui, a)
if err != nil {
t.Fatalf("err: %s", err)
}
......@@ -73,18 +72,18 @@ func TestPostProcessorRPC(t *testing.T) {
t.Fatal("postprocess should be called")
}
if p.ppArtifact.BuilderId() != "bid" {
t.Fatal("unknown artifact")
if p.ppArtifactId != "ppTestId" {
t.Fatalf("unknown artifact: %s", p.ppArtifact.Id())
}
if artifact.BuilderId() != "bid" {
t.Fatal("unknown result artifact")
if artifact.Id() != "id" {
t.Fatalf("unknown artifact: %s", artifact.Id())
}
}
func TestPostProcessor_Implements(t *testing.T) {
var raw interface{}
raw = PostProcessor(nil)
raw = new(postProcessor)
if _, ok := raw.(packer.PostProcessor); !ok {
t.Fatal("not a postprocessor")
}
......
......@@ -10,25 +10,20 @@ import (
// executed over an RPC connection.
type provisioner struct {
client *rpc.Client
mux *MuxConn
}
// ProvisionerServer wraps a packer.Provisioner implementation and makes it
// exportable as part of a Golang RPC server.
type ProvisionerServer struct {
p packer.Provisioner
mux *MuxConn
}
type ProvisionerPrepareArgs struct {
Configs []interface{}
}
type ProvisionerProvisionArgs struct {
RPCAddress string
}
func Provisioner(client *rpc.Client) *provisioner {
return &provisioner{client}
}
func (p *provisioner) Prepare(configs ...interface{}) (err error) {
args := &ProvisionerPrepareArgs{configs}
if cerr := p.client.Call("Provisioner.Prepare", args, &err); cerr != nil {
......@@ -39,13 +34,13 @@ func (p *provisioner) Prepare(configs ...interface{}) (err error) {
}
func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
// TODO: Error handling
server := rpc.NewServer()
RegisterCommunicator(server, comm)
RegisterUi(server, ui)
nextId := p.mux.NextId()
server := NewServerWithMux(p.mux, nextId)
server.RegisterCommunicator(comm)
server.RegisterUi(ui)
go server.Serve()
args := &ProvisionerProvisionArgs{serveSingleConn(server)}
return p.client.Call("Provisioner.Provision", args, new(interface{}))
return p.client.Call("Provisioner.Provision", nextId, new(interface{}))
}
func (p *provisioner) Cancel() {
......@@ -64,16 +59,14 @@ func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *error)
return nil
}
func (p *ProvisionerServer) Provision(args *ProvisionerProvisionArgs, reply *interface{}) error {
client, err := rpcDial(args.RPCAddress)
func (p *ProvisionerServer) Provision(streamId uint32, reply *interface{}) error {
client, err := NewClientWithMux(p.mux, streamId)
if err != nil {
return err
return NewBasicError(err)
}
defer client.Close()
comm := Communicator(client)
ui := &Ui{client}
if err := p.p.Provision(ui, comm); err != nil {
if err := p.p.Provision(client.Ui(), client.Communicator()); err != nil {
return NewBasicError(err)
}
......
......@@ -2,7 +2,6 @@ package rpc
import (
"github.com/mitchellh/packer/packer"
"net/rpc"
"reflect"
"testing"
)
......@@ -12,19 +11,14 @@ func TestProvisionerRPC(t *testing.T) {
p := new(packer.MockProvisioner)
// Start the server
server := rpc.NewServer()
RegisterProvisioner(server, p)
address := serveSingleConn(server)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
t.Fatalf("err: %s", err)
}
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterProvisioner(p)
pClient := client.Provisioner()
// Test Prepare
config := 42
pClient := Provisioner(client)
pClient.Prepare(config)
if !p.PrepCalled {
t.Fatal("should be called")
......@@ -41,11 +35,6 @@ func TestProvisionerRPC(t *testing.T) {
t.Fatal("should be called")
}
p.ProvUi.Say("foo")
if !ui.sayCalled {
t.Fatal("should be called")
}
// Test Cancel
pClient.Cancel()
if !p.CancelCalled {
......@@ -54,5 +43,5 @@ func TestProvisionerRPC(t *testing.T) {
}
func TestProvisioner_Implements(t *testing.T) {
var _ packer.Provisioner = Provisioner(nil)
var _ packer.Provisioner = new(provisioner)
}
package rpc
import (
"fmt"
"github.com/mitchellh/packer/packer"
"io"
"log"
"net/rpc"
"sync/atomic"
)
// Registers the appropriate endpoint on an RPC server to serve an
// Artifact.
func RegisterArtifact(s *rpc.Server, a packer.Artifact) {
s.RegisterName("Artifact", &ArtifactServer{a})
var endpointId uint64
const (
DefaultArtifactEndpoint string = "Artifact"
DefaultBuildEndpoint = "Build"
DefaultBuilderEndpoint = "Builder"
DefaultCacheEndpoint = "Cache"
DefaultCommandEndpoint = "Command"
DefaultCommunicatorEndpoint = "Communicator"
DefaultEnvironmentEndpoint = "Environment"
DefaultHookEndpoint = "Hook"
DefaultPostProcessorEndpoint = "PostProcessor"
DefaultProvisionerEndpoint = "Provisioner"
DefaultUiEndpoint = "Ui"
)
// Server represents an RPC server for Packer. This must be paired on
// the other side with a Client.
type Server struct {
mux *MuxConn
streamId uint32
server *rpc.Server
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Build.
func RegisterBuild(s *rpc.Server, b packer.Build) {
s.RegisterName("Build", &BuildServer{b})
// NewServer returns a new Packer RPC server.
func NewServer(conn io.ReadWriteCloser) *Server {
return NewServerWithMux(NewMuxConn(conn), 0)
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Builder.
func RegisterBuilder(s *rpc.Server, b packer.Builder) {
s.RegisterName("Builder", &BuilderServer{b})
func NewServerWithMux(mux *MuxConn, streamId uint32) *Server {
return &Server{
mux: mux,
streamId: streamId,
server: rpc.NewServer(),
}
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Cache.
func RegisterCache(s *rpc.Server, c packer.Cache) {
s.RegisterName("Cache", &CacheServer{c})
func (s *Server) Close() error {
return s.mux.Close()
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Command.
func RegisterCommand(s *rpc.Server, c packer.Command) {
s.RegisterName("Command", &CommandServer{c})
func (s *Server) RegisterArtifact(a packer.Artifact) {
s.server.RegisterName(DefaultArtifactEndpoint, &ArtifactServer{
artifact: a,
})
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Communicator.
func RegisterCommunicator(s *rpc.Server, c packer.Communicator) {
s.RegisterName("Communicator", &CommunicatorServer{c})
func (s *Server) RegisterBuild(b packer.Build) {
s.server.RegisterName(DefaultBuildEndpoint, &BuildServer{
build: b,
mux: s.mux,
})
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer Environment
func RegisterEnvironment(s *rpc.Server, e packer.Environment) {
s.RegisterName("Environment", &EnvironmentServer{e})
func (s *Server) RegisterBuilder(b packer.Builder) {
s.server.RegisterName(DefaultBuilderEndpoint, &BuilderServer{
builder: b,
mux: s.mux,
})
}
// Registers the appropriate endpoint on an RPC server to serve a
// Hook.
func RegisterHook(s *rpc.Server, hook packer.Hook) {
s.RegisterName("Hook", &HookServer{hook})
func (s *Server) RegisterCache(c packer.Cache) {
s.server.RegisterName(DefaultCacheEndpoint, &CacheServer{
cache: c,
})
}
// Registers the appropriate endpoing on an RPC server to serve a
// PostProcessor.
func RegisterPostProcessor(s *rpc.Server, p packer.PostProcessor) {
s.RegisterName("PostProcessor", &PostProcessorServer{p})
func (s *Server) RegisterCommand(c packer.Command) {
s.server.RegisterName(DefaultCommandEndpoint, &CommandServer{
command: c,
mux: s.mux,
})
}
// Registers the appropriate endpoint on an RPC server to serve a packer.Provisioner
func RegisterProvisioner(s *rpc.Server, p packer.Provisioner) {
s.RegisterName("Provisioner", &ProvisionerServer{p})
func (s *Server) RegisterCommunicator(c packer.Communicator) {
s.server.RegisterName(DefaultCommunicatorEndpoint, &CommunicatorServer{
c: c,
mux: s.mux,
})
}
// Registers the appropriate endpoint on an RPC server to serve a
// Packer UI
func RegisterUi(s *rpc.Server, ui packer.Ui) {
s.RegisterName("Ui", &UiServer{ui})
func (s *Server) RegisterEnvironment(b packer.Environment) {
s.server.RegisterName(DefaultEnvironmentEndpoint, &EnvironmentServer{
env: b,
mux: s.mux,
})
}
func serveSingleConn(s *rpc.Server) string {
l := netListenerInRange(portRangeMin, portRangeMax)
func (s *Server) RegisterHook(h packer.Hook) {
s.server.RegisterName(DefaultHookEndpoint, &HookServer{
hook: h,
mux: s.mux,
})
}
// Accept a single connection in a goroutine and then exit
go func() {
defer l.Close()
conn, err := l.Accept()
func (s *Server) RegisterPostProcessor(p packer.PostProcessor) {
s.server.RegisterName(DefaultPostProcessorEndpoint, &PostProcessorServer{
mux: s.mux,
p: p,
})
}
func (s *Server) RegisterProvisioner(p packer.Provisioner) {
s.server.RegisterName(DefaultProvisionerEndpoint, &ProvisionerServer{
mux: s.mux,
p: p,
})
}
func (s *Server) RegisterUi(ui packer.Ui) {
s.server.RegisterName(DefaultUiEndpoint, &UiServer{
ui: ui,
})
}
// ServeConn serves a single connection over the RPC server. It is up
// to the caller to obtain a proper io.ReadWriteCloser.
func (s *Server) Serve() {
// Accept a connection on stream ID 0, which is always used for
// normal client to server connections.
stream, err := s.mux.Accept(s.streamId)
defer stream.Close()
if err != nil {
panic(err)
log.Printf("[ERR] Error retrieving stream for serving: %s", err)
return
}
s.ServeConn(conn)
}()
s.server.ServeConn(stream)
}
// registerComponent registers a single Packer RPC component onto
// the RPC server. If id is true, then a unique ID number will be appended
// onto the end of the endpoint.
//
// The endpoint name is returned.
func registerComponent(server *rpc.Server, name string, rcvr interface{}, id bool) string {
endpoint := name
if id {
fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&endpointId, 1))
}
return l.Addr().String()
server.RegisterName(endpoint, rcvr)
return endpoint
}
......@@ -10,6 +10,7 @@ import (
// over an RPC connection.
type Ui struct {
client *rpc.Client
endpoint string
}
// UiServer wraps a packer.Ui implementation and makes it exportable
......
package rpc
import (
"net/rpc"
"reflect"
"testing"
)
......@@ -52,17 +51,12 @@ func TestUiRPC(t *testing.T) {
ui := new(testUi)
// Start the RPC server
server := rpc.NewServer()
RegisterUi(server, ui)
address := serveSingleConn(server)
client, server := testClientServer(t)
defer client.Close()
defer server.Close()
server.RegisterUi(ui)
// Create the client over RPC and run some methods to verify it works
client, err := rpc.Dial("tcp", address)
if err != nil {
panic(err)
}
uiClient := &Ui{client}
uiClient := client.Ui()
// Basic error and say tests
result, err := uiClient.Ask("query")
......
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(chroot.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(chroot.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(ebs.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(ebs.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(instance.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(instance.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(digitalocean.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(digitalocean.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(docker.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(docker.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(openstack.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(openstack.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(qemu.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(qemu.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(virtualbox.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(virtualbox.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeBuilder(new(vmware.Builder))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterBuilder(new(vmware.Builder))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeCommand(new(build.Command))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterCommand(new(build.Command))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeCommand(new(fix.Command))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterCommand(new(fix.Command))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeCommand(new(inspect.Command))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterCommand(new(inspect.Command))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeCommand(new(validate.Command))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterCommand(new(validate.Command))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServePostProcessor(new(vagrant.PostProcessor))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterPostProcessor(new(vagrant.PostProcessor))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServePostProcessor(new(vsphere.PostProcessor))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterPostProcessor(new(vsphere.PostProcessor))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(ansiblelocal.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(ansiblelocal.Provisioner))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(chefsolo.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(chefsolo.Provisioner))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(file.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(file.Provisioner))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(puppetmasterless.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(puppetmasterless.Provisioner))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(saltmasterless.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(saltmasterless.Provisioner))
server.Serve()
}
......@@ -6,5 +6,10 @@ import (
)
func main() {
plugin.ServeProvisioner(new(shell.Provisioner))
server, err := plugin.Server()
if err != nil {
panic(err)
}
server.RegisterProvisioner(new(shell.Provisioner))
server.Serve()
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment