Commit 03850caf authored by Chris Bednarski's avatar Chris Bednarski

Implemented timeout around the SSH handshake, including a unit test

parent 6ca48fa3
...@@ -5,9 +5,6 @@ import ( ...@@ -5,9 +5,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -16,8 +13,15 @@ import ( ...@@ -16,8 +13,15 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync" "sync"
"time"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
) )
var ErrHandshakeTimeout = fmt.Errorf("Timeout during SSH handshake")
type comm struct { type comm struct {
client *ssh.Client client *ssh.Client
config *Config config *Config
...@@ -40,6 +44,10 @@ type Config struct { ...@@ -40,6 +44,10 @@ type Config struct {
// DisableAgent, if true, will not forward the SSH agent. // DisableAgent, if true, will not forward the SSH agent.
DisableAgent bool DisableAgent bool
// HandshakeTimeout limits the amount of time we'll wait to handshake before
// saying the connection failed.
HandshakeTimeout time.Duration
} }
// Creates a new packer.Communicator implementation over SSH. This takes // Creates a new packer.Communicator implementation over SSH. This takes
...@@ -273,9 +281,39 @@ func (c *comm) reconnect() (err error) { ...@@ -273,9 +281,39 @@ func (c *comm) reconnect() (err error) {
} }
log.Printf("handshaking with SSH") log.Printf("handshaking with SSH")
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
// Default timeout to 1 minute if it wasn't specified (zero value). For
// when you need to handshake from low orbit.
var duration time.Duration
if c.config.HandshakeTimeout == 0 {
duration = 1 * time.Minute
} else {
duration = c.config.HandshakeTimeout
}
timeoutExceeded := time.After(duration)
connectionEstablished := make(chan bool, 1)
var sshConn ssh.Conn
var sshChan <-chan ssh.NewChannel
var req <-chan *ssh.Request
go func() {
sshConn, sshChan, req, err = ssh.NewClientConn(c.conn, c.address, c.config.SSHConfig)
connectionEstablished <- true
}()
select {
case <-connectionEstablished:
// We don't need to do anything here. We just want select to block until
// we connect or timeout.
case <-timeoutExceeded:
return ErrHandshakeTimeout
}
if err != nil { if err != nil {
log.Printf("handshake error: %s", err) log.Printf("handshake error: %s", err)
return
} }
log.Printf("handshake complete!") log.Printf("handshake complete!")
if sshConn != nil { if sshConn != nil {
......
...@@ -5,10 +5,12 @@ package ssh ...@@ -5,10 +5,12 @@ package ssh
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
"net" "net"
"testing" "testing"
"time"
"github.com/mitchellh/packer/packer"
"golang.org/x/crypto/ssh"
) )
// private key for mock server // private key for mock server
...@@ -94,6 +96,28 @@ func newMockLineServer(t *testing.T) string { ...@@ -94,6 +96,28 @@ func newMockLineServer(t *testing.T) string {
return l.Addr().String() return l.Addr().String()
} }
func newMockBrokenServer(t *testing.T) string {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable tp listen for connection: %s", err)
}
go func() {
defer l.Close()
c, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept incoming connection: %s", err)
}
defer c.Close()
// This should block for a period of time longer than our timeout in
// the test case. That way we invoke a failure scenario.
time.Sleep(5 * time.Second)
t.Log("Block on handshaking for SSH connection")
}()
return l.Addr().String()
}
func TestCommIsCommunicator(t *testing.T) { func TestCommIsCommunicator(t *testing.T) {
var raw interface{} var raw interface{}
raw = &comm{} raw = &comm{}
...@@ -157,10 +181,44 @@ func TestStart(t *testing.T) { ...@@ -157,10 +181,44 @@ func TestStart(t *testing.T) {
t.Fatalf("error connecting to SSH: %s", err) t.Fatalf("error connecting to SSH: %s", err)
} }
var cmd packer.RemoteCmd cmd := &packer.RemoteCmd{
stdout := new(bytes.Buffer) Command: "echo foo",
cmd.Command = "echo foo" Stdout: new(bytes.Buffer),
cmd.Stdout = stdout }
client.Start(cmd)
}
func TestHandshakeTimeout(t *testing.T) {
clientConfig := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
ssh.Password("pass"),
},
}
client.Start(&cmd) address := newMockBrokenServer(t)
conn := func() (net.Conn, error) {
conn, err := net.Dial("tcp", address)
if err != nil {
t.Fatalf("unable to dial to remote side: %s", err)
}
return conn, err
}
config := &Config{
Connection: conn,
SSHConfig: clientConfig,
HandshakeTimeout: 50 * time.Millisecond,
}
_, err := New(address, config)
if err != ErrHandshakeTimeout {
// Note: there's another error that can come back from this call:
// ssh: handshake failed: EOF
// This should appear in cases where the handshake fails because of
// malformed (or no) data sent back by the server, but should not happen
// in a timeout scenario.
t.Fatalf("Expected handshake timeout, got: %s", err)
}
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment