testhelper.go 4.79 KB
Newer Older
1 2 3
package testhelper

import (
4 5
	"bufio"
	"bytes"
6 7 8
	"errors"
	"fmt"
	"io/ioutil"
9 10
	"net/http"
	"net/http/httptest"
11 12 13
	"os"
	"os/exec"
	"path"
14
	"regexp"
15
	"runtime"
16
	"strings"
17
	"testing"
18

19 20
	log "github.com/sirupsen/logrus"

21
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
22 23
)

24 25
func ConfigureSecret() {
	secret.SetPath(path.Join(RootDir(), "testdata/test-secret"))
26 27
}

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
var extractPatchSeriesMatcher = regexp.MustCompile(`^From (\w+)`)

// AssertPatchSeries takes a `git format-patch` blob, extracts the SHA from
// every line starting with "From <sha>" and compares that list, in order, to
// expected.
//
// It also verifies that the final patch is complete. The third-to-last line of
// blob must be the "-- " signature separator. The two lines after it
// (presumably the git version and a trailing blank line) are not inspected.
//
// Any mismatch, or an error while scanning blob, fails t immediately.
func AssertPatchSeries(t *testing.T, blob []byte, expected ...string) {
	var actual []string
	// Sliding window over the last three lines seen.
	footer := make([]string, 3)

	scanner := bufio.NewScanner(bytes.NewReader(blob))
	for scanner.Scan() {
		line := scanner.Text()
		if matches := extractPatchSeriesMatcher.FindStringSubmatch(line); len(matches) == 2 {
			actual = append(actual, matches[1])
		}
		// Shift the window left in place instead of allocating a new slice per line.
		copy(footer, footer[1:])
		footer[2] = line
	}
	// Without this check, a scan error (e.g. a line longer than
	// bufio.MaxScanTokenSize) would silently truncate the series and surface
	// as a confusing mismatch below.
	if err := scanner.Err(); err != nil {
		t.Fatalf("Failed to scan patch series: %v", err)
	}

	if strings.Join(actual, "\n") != strings.Join(expected, "\n") {
		t.Fatalf("Patch series differs. Expected: %v. Got: %v", expected, actual)
	}

	// Check the last returned patch is complete.
	// Don't assert on the two trailing lines; they carry the git version.
	if footer[0] != "-- " {
		t.Fatalf("Expected end of patch, found: \n\t%q", strings.Join(footer, "\n\t"))
	}
}

57 58 59 60 61 62 63 64 65 66 67 68
// AssertResponseCode fails t immediately when the status code recorded by
// response is not expectedCode.
func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
	got := response.Code
	if got == expectedCode {
		return
	}
	t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, got)
}

// AssertResponseBody fails t immediately unless the body recorded by response
// is exactly expectedBody.
func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) {
	body := response.Body.String()
	if body == expectedBody {
		return
	}
	t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, body)
}

69 70 71 72 73 74
// AssertResponseBodyRegexp fails t immediately unless the body recorded by
// response matches the expectedBody pattern.
func AssertResponseBodyRegexp(t *testing.T, response *httptest.ResponseRecorder, expectedBody *regexp.Regexp) {
	body := response.Body.String()
	if expectedBody.MatchString(body) {
		return
	}
	t.Fatalf("for HTTP request expected to receive body matching %q, got %q instead", expectedBody.String(), body)
}

75
func AssertResponseWriterHeader(t *testing.T, w http.ResponseWriter, header string, expected ...string) {
76 77
	actual := w.Header()[http.CanonicalHeaderKey(header)]

78 79 80
	assertHeaderExists(t, header, actual, expected)
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94
func AssertResponseHeader(t *testing.T, w interface{}, header string, expected ...string) {
	var actual []string

	header = http.CanonicalHeaderKey(header)

	if resp, ok := w.(*http.Response); ok {
		actual = resp.Header[header]
	} else if resp, ok := w.(http.ResponseWriter); ok {
		actual = resp.Header()[header]
	} else if resp, ok := w.(*httptest.ResponseRecorder); ok {
		actual = resp.Header()[header]
	} else {
		t.Fatalf("invalid type of w passed AssertResponseHeader")
	}
95 96 97 98 99

	assertHeaderExists(t, header, actual, expected)
}

// assertHeaderExists fails t unless actual and expected hold the same values
// in the same order. header is used only in the failure message.
func assertHeaderExists(t *testing.T, header string, actual, expected []string) {
	mismatch := len(actual) != len(expected)
	for i := 0; !mismatch && i < len(expected); i++ {
		mismatch = expected[i] != actual[i]
	}

	if mismatch {
		t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual)
	}
}

// TestServerWithHandler starts an httptest.Server acting as a fake upstream.
//
// Requests are handled as follows:
//   - If url is non-nil and does not match the request path, the request is
//     logged as denied and answered with 404.
//   - If the request lacks a Gitlab-Workhorse header, it is logged as denied
//     and answered with 403.
//   - Every other request is passed to handler.
//
// The caller is responsible for closing the returned server.
func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
	deny := func(w http.ResponseWriter, r *http.Request, status int) {
		log.Println("UPSTREAM", r.Method, r.URL, "DENY")
		w.WriteHeader(status)
	}

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case url != nil && !url.MatchString(r.URL.Path):
			deny(w, r, http.StatusNotFound)
		case r.Header.Get("Gitlab-Workhorse") == "":
			deny(w, r, http.StatusForbidden)
		default:
			handler(w, r)
		}
	}))
}
128 129

func BuildExecutables() (func(), error) {
130
	rootDir := RootDir()
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149

	// This method will be invoked more than once due to Go test
	// parallelization. We must use a unique temp directory for each
	// invokation so that they do not trample each other's builds.
	testDir, err := ioutil.TempDir("", "gitlab-workhorse-test")
	if err != nil {
		return nil, errors.New("could not create temp directory")
	}

	makeCmd := exec.Command("make", "BUILD_DIR="+testDir)
	makeCmd.Dir = rootDir
	makeCmd.Stderr = os.Stderr
	makeCmd.Stdout = os.Stdout
	if err := makeCmd.Run(); err != nil {
		return nil, fmt.Errorf("failed to run %v in %v", makeCmd, rootDir)
	}

	oldPath := os.Getenv("PATH")
	testPath := fmt.Sprintf("%s:%s", testDir, oldPath)
Jacob Vosmaer's avatar
Jacob Vosmaer committed
150 151
	if err := os.Setenv("PATH", testPath); err != nil {
		return nil, fmt.Errorf("failed to set PATH to %v", testPath)
152 153 154 155 156 157 158
	}

	return func() {
		os.Setenv("PATH", oldPath)
		os.RemoveAll(testDir)
	}, nil
}
159 160 161 162 163 164 165 166

// RootDir returns the directory two levels above the directory containing
// this source file. The path comes from the file name recorded for RootDir
// itself via runtime.Caller(0). It panics if runtime.Caller cannot report it.
func RootDir() string {
	_, thisFile, _, found := runtime.Caller(0)
	if found {
		return path.Join(path.Dir(thisFile), "../..")
	}
	panic(errors.New("RootDir: calling runtime.Caller failed"))
}