Commit 4e699390 authored by Jack Pearkes's avatar Jack Pearkes

builder/digitalocean: builder config tests and create_ssh_key step

parent 8599af62
...@@ -16,7 +16,7 @@ const DIGITALOCEAN_API_URL = "https://api.digitalocean.com" ...@@ -16,7 +16,7 @@ const DIGITALOCEAN_API_URL = "https://api.digitalocean.com"
type DigitalOceanClient struct { type DigitalOceanClient struct {
// The http client for communicating // The http client for communicating
client *http.client client *http.Client
// The base URL of the API // The base URL of the API
BaseURL string BaseURL string
...@@ -27,7 +27,7 @@ type DigitalOceanClient struct { ...@@ -27,7 +27,7 @@ type DigitalOceanClient struct {
} }
// Creates a new client for communicating with DO // Creates a new client for communicating with DO
func (d DigitalOceanClient) New(client string, key string) *Client { func (d DigitalOceanClient) New(client string, key string) *DigitalOceanClient {
c := &DigitalOceanClient{ c := &DigitalOceanClient{
client: http.DefaultClient, client: http.DefaultClient,
BaseURL: DIGITALOCEAN_API_URL, BaseURL: DIGITALOCEAN_API_URL,
...@@ -43,7 +43,7 @@ func (d DigitalOceanClient) CreateKey(name string, pub string) (uint, error) { ...@@ -43,7 +43,7 @@ func (d DigitalOceanClient) CreateKey(name string, pub string) (uint, error) {
body, err := NewRequest(d, "ssh_keys/new", params) body, err := NewRequest(d, "ssh_keys/new", params)
if err != nil { if err != nil {
return nil, err return 0, err
} }
// Read the SSH key's ID we just created // Read the SSH key's ID we just created
...@@ -67,13 +67,13 @@ func (d DigitalOceanClient) CreateDroplet(name string, size uint, image uint, re ...@@ -67,13 +67,13 @@ func (d DigitalOceanClient) CreateDroplet(name string, size uint, image uint, re
body, err := NewRequest(d, "droplets/new", params) body, err := NewRequest(d, "droplets/new", params)
if err != nil { if err != nil {
return nil, err return 0, err
} }
// Read the Droplets ID // Read the Droplets ID
droplet := body["droplet"].(map[string]interface{}) droplet := body["droplet"].(map[string]interface{})
dropletId := droplet["id"].(float64) dropletId := droplet["id"].(float64)
return dropletId, err return uint(dropletId), err
} }
// Destroys a droplet // Destroys a droplet
...@@ -84,7 +84,7 @@ func (d DigitalOceanClient) DestroyDroplet(id uint) error { ...@@ -84,7 +84,7 @@ func (d DigitalOceanClient) DestroyDroplet(id uint) error {
} }
// Powers off a droplet // Powers off a droplet
func (d DigitalOceanClient) PowerOffDroplet(name string, pub string) error { func (d DigitalOceanClient) PowerOffDroplet(id uint) error {
path := fmt.Sprintf("droplets/%s/power_off", id) path := fmt.Sprintf("droplets/%s/power_off", id)
_, err := NewRequest(d, path, "") _, err := NewRequest(d, path, "")
...@@ -108,7 +108,7 @@ func (d DigitalOceanClient) DropletStatus(id uint) (string, error) { ...@@ -108,7 +108,7 @@ func (d DigitalOceanClient) DropletStatus(id uint) (string, error) {
body, err := NewRequest(d, path, "") body, err := NewRequest(d, path, "")
if err != nil { if err != nil {
return nil, err return "", err
} }
// Read the droplet's "status" // Read the droplet's "status"
...@@ -125,37 +125,37 @@ func NewRequest(d DigitalOceanClient, path string, params string) (map[string]in ...@@ -125,37 +125,37 @@ func NewRequest(d DigitalOceanClient, path string, params string) (map[string]in
url := fmt.Sprintf("%s/%s?%s&client_id=%s&api_key=%s", url := fmt.Sprintf("%s/%s?%s&client_id=%s&api_key=%s",
DIGITALOCEAN_API_URL, path, params, d.ClientID, d.APIKey) DIGITALOCEAN_API_URL, path, params, d.ClientID, d.APIKey)
var decodedResponse map[string]interface{}
resp, err := client.Get(url) resp, err := client.Get(url)
if err != nil { if err != nil {
return nil, err return decodedResponse, err
} }
body, err = ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
return nil, err return decodedResponse, err
} }
// Catch all non-200 status and return an error // Catch all non-200 status and return an error
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
err = errors.New("recieved non-200 status from digitalocean: %d", resp.StatusCode) err = errors.New(fmt.Sprintf("recieved non-200 status from digitalocean: %d", resp.StatusCode))
return nil, err return decodedResponse, err
} }
var decodedResponse map[string]interface{}
err = json.Unmarshal(body, &decodedResponse) err = json.Unmarshal(body, &decodedResponse)
if err != nil { if err != nil {
return nil, err return decodedResponse, err
} }
// Catch all non-OK statuses from DO and return an error // Catch all non-OK statuses from DO and return an error
status := decodedResponse["status"] status := decodedResponse["status"]
if status != "OK" { if status != "OK" {
err = errors.New("recieved non-OK status from digitalocean: %d", status) err = errors.New(fmt.Sprintf("recieved non-OK status from digitalocean: %d", status))
return nil, err return decodedResponse, err
} }
return decodedResponse, nil return decodedResponse, nil
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
package digitalocean package digitalocean
import ( import (
"errors"
"fmt"
"github.com/mitchellh/mapstructure"
"github.com/mitchellh/multistep" "github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer" "github.com/mitchellh/packer/packer"
"log" "log"
...@@ -31,6 +34,7 @@ type config struct { ...@@ -31,6 +34,7 @@ type config struct {
SnapshotName string `mapstructure:"snapshot_name"` SnapshotName string `mapstructure:"snapshot_name"`
RawSSHTimeout string `mapstructure:"ssh_timeout"` RawSSHTimeout string `mapstructure:"ssh_timeout"`
SSHTimeout time.Duration
} }
type Builder struct { type Builder struct {
...@@ -86,18 +90,18 @@ func (b *Builder) Prepare(raw interface{}) error { ...@@ -86,18 +90,18 @@ func (b *Builder) Prepare(raw interface{}) error {
// Required configurations that will display errors if not set // Required configurations that will display errors if not set
// //
if b.config.ClientId == "" { if b.config.ClientID == "" {
errs = append(errs, errors.New("a client_id must be specified")) errs = append(errs, errors.New("a client_id must be specified"))
} }
if b.config.APIKey == "" { if b.config.APIKey == "" {
errs = append(errs, errors.New("an api_key must be specified")) errs = append(errs, errors.New("an api_key must be specified"))
} }
timeout, err := time.ParseDuration(b.config.RawSSHTimeout)
b.config.SSHTimeout, err = time.ParseDuration(b.config.RawSSHTimeout)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("Failed parsing ssh_timeout: %s", err)) errs = append(errs, fmt.Errorf("Failed parsing ssh_timeout: %s", err))
} }
b.config.SSHTimeout = timeout
if len(errs) > 0 { if len(errs) > 0 {
return &packer.MultiError{errs} return &packer.MultiError{errs}
...@@ -108,5 +112,38 @@ func (b *Builder) Prepare(raw interface{}) error { ...@@ -108,5 +112,38 @@ func (b *Builder) Prepare(raw interface{}) error {
} }
func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) { func (b *Builder) Run(ui packer.Ui, hook packer.Hook, cache packer.Cache) (packer.Artifact, error) {
// Initialize the DO API client
client := DigitalOceanClient{}.New(b.config.ClientID, b.config.APIKey)
// Set up the state
state := make(map[string]interface{})
state["config"] = b.config
state["client"] = client
state["hook"] = hook
state["ui"] = ui
// Build the steps
steps := []multistep.Step{
new(stepCreateSSHKey),
new(stepCreateDroplet),
new(stepConnectSSH),
new(stepProvision),
new(stepPowerOff),
new(stepSnapshot),
new(stepDestroyDroplet),
new(stepDestroySSHKey),
}
// Run the steps
b.runner = &multistep.BasicRunner{Steps: steps}
b.runner.Run(state)
return nil, nil
}
func (b *Builder) Cancel() {
if b.runner != nil {
log.Println("Cancelling the step runner...")
b.runner.Cancel()
}
} }
package digitalocean
import (
"github.com/mitchellh/packer/packer"
"testing"
)
func testConfig() map[string]interface{} {
return map[string]interface{}{
"client_id": "foo",
"api_key": "bar",
}
}
func TestBuilder_ImplementsBuilder(t *testing.T) {
var raw interface{}
raw = &Builder{}
if _, ok := raw.(packer.Builder); !ok {
t.Fatalf("Builder should be a builder")
}
}
func TestBuilder_Prepare_BadType(t *testing.T) {
b := &Builder{}
c := map[string]interface{}{
"api_key": []string{},
}
err := b.Prepare(c)
if err == nil {
t.Fatalf("prepare should fail")
}
}
func TestBuilderPrepare_APIKey(t *testing.T) {
var b Builder
config := testConfig()
// Test good
config["api_key"] = "foo"
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.APIKey != "foo" {
t.Errorf("access key invalid: %s", b.config.APIKey)
}
// Test bad
delete(config, "api_key")
b = Builder{}
err = b.Prepare(config)
if err == nil {
t.Fatal("should have error")
}
}
func TestBuilderPrepare_ClientID(t *testing.T) {
var b Builder
config := testConfig()
// Test good
config["client_id"] = "foo"
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.ClientID != "foo" {
t.Errorf("invalid: %s", b.config.ClientID)
}
// Test bad
delete(config, "client_id")
b = Builder{}
err = b.Prepare(config)
if err == nil {
t.Fatal("should have error")
}
}
func TestBuilderPrepare_RegionID(t *testing.T) {
var b Builder
config := testConfig()
// Test default
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.RegionID != 1 {
t.Errorf("invalid: %d", b.config.RegionID)
}
// Test set
config["region_id"] = 2
b = Builder{}
err = b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.RegionID != 2 {
t.Errorf("invalid: %d", b.config.RegionID)
}
}
func TestBuilderPrepare_SizeID(t *testing.T) {
var b Builder
config := testConfig()
// Test default
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SizeID != 66 {
t.Errorf("invalid: %d", b.config.SizeID)
}
// Test set
config["size_id"] = 67
b = Builder{}
err = b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SizeID != 67 {
t.Errorf("invalid: %d", b.config.SizeID)
}
}
func TestBuilderPrepare_ImageID(t *testing.T) {
var b Builder
config := testConfig()
// Test default
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SizeID != 2676 {
t.Errorf("invalid: %d", b.config.SizeID)
}
// Test set
config["size_id"] = 2
b = Builder{}
err = b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SizeID != 2 {
t.Errorf("invalid: %d", b.config.SizeID)
}
}
func TestBuilderPrepare_SSHUsername(t *testing.T) {
var b Builder
config := testConfig()
// Test default
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SSHUsername != "root" {
t.Errorf("invalid: %d", b.config.SSHUsername)
}
// Test set
config["ssh_username"] = ""
b = Builder{}
err = b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SSHPort != 35 {
t.Errorf("invalid: %d", b.config.SSHPort)
}
}
func TestBuilderPrepare_SSHTimeout(t *testing.T) {
var b Builder
config := testConfig()
// Test default
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.RawSSHTimeout != "1m" {
t.Errorf("invalid: %d", b.config.RawSSHTimeout)
}
// Test set
config["ssh_timeout"] = "30s"
b = Builder{}
err = b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
// Test bad
config["ssh_timeout"] = "tubes"
b = Builder{}
err = b.Prepare(config)
if err == nil {
t.Fatal("should have error")
}
}
func TestBuilderPrepare_SnapshotName(t *testing.T) {
var b Builder
config := testConfig()
// Test set
config["snapshot_name"] = "foo"
err := b.Prepare(config)
if err != nil {
t.Fatalf("should not have error: %s", err)
}
if b.config.SnapshotName != "foo" {
t.Errorf("invalid: %s", b.config.SnapshotName)
}
}
package digitalocean
import (
"cgl.tideland.biz/identifier"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"github.com/mitchellh/multistep"
"github.com/mitchellh/packer/packer"
"log"
)
type stepCreateSSHKey struct {
keyId uint
}
func (s *stepCreateSSHKey) Run(state map[string]interface{}) multistep.StepAction {
client := state["client"].(*DigitalOceanClient)
ui := state["ui"].(packer.Ui)
ui.Say("Creating temporary ssh key for droplet...")
priv, err := rsa.GenerateKey(rand.Reader, 2014)
if err != nil {
ui.Error(err.Error())
return multistep.ActionHalt
}
// Set the pem formatted private key on the state for later
priv_der := x509.MarshalPKCS1PrivateKey(priv)
priv_blk := pem.Block{
Type: "RSA PRIVATE KEY",
Headers: nil,
Bytes: priv_der,
}
// Create the public key for uploading to DO
pub := priv.PublicKey
pub_der, err := x509.MarshalPKIXPublicKey(&pub)
if err != nil {
ui.Error(err.Error())
return multistep.ActionHalt
}
pub_blk := pem.Block{
Type: "PUBLIC KEY",
Headers: nil,
Bytes: pub_der,
}
pub_pem := string(pem.EncodeToMemory(&pub_blk))
name := fmt.Sprintf("packer %s", hex.EncodeToString(identifier.NewUUID().Raw()))
keyId, err := client.CreateKey(name, pub_pem)
if err != nil {
ui.Error(err.Error())
return multistep.ActionHalt
}
// We use this to check cleanup
s.keyId = keyId
log.Printf("temporary ssh key name: %s", name)
// Remember some state for the future
state["keyId"] = keyId
state["privateKey"] = string(pem.EncodeToMemory(&priv_blk))
return multistep.ActionContinue
}
func (s *stepCreateSSHKey) Cleanup(state map[string]interface{}) {
// If no key name is set, then we never created it, so just return
if s.keyId == 0 {
return
}
client := state["client"].(*DigitalOceanClient)
ui := state["ui"].(packer.Ui)
ui.Say("Deleting temporary ssh key...")
err := client.DestroyKey(s.keyId)
if err != nil {
ui.Error(fmt.Sprintf(
"Error cleaning up ssh key. Please delete the key manually: %s", s.keyId))
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment