Commit 9d9e1617 authored by Igor's avatar Igor Committed by Nick Thomas

Support calling internal api using HTTP

parent 6e9b4dec
......@@ -17,8 +17,7 @@ import (
)
var (
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
......@@ -46,7 +45,7 @@ var (
)
func TestExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
......@@ -79,7 +78,7 @@ func TestExecute(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
......@@ -91,7 +90,7 @@ func TestExecute(t *testing.T) {
}
func TestFailingExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
......@@ -119,7 +118,7 @@ func TestFailingExecute(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
......
......@@ -18,12 +18,10 @@ import (
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
requests []testserver.TestRequestHandler
)
func setup(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
......@@ -66,7 +64,7 @@ const (
func TestExecute(t *testing.T) {
setup(t)
cleanup, err := testserver.StartSocketHttpServer(requests)
cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
......@@ -124,7 +122,7 @@ func TestExecute(t *testing.T) {
output := &bytes.Buffer{}
input := bytes.NewBufferString(tc.answer)
cmd := &Command{Config: testConfig, Args: tc.arguments}
cmd := &Command{Config: &config.Config{GitlabUrl: url}, Args: tc.arguments}
err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
......
......@@ -22,15 +22,23 @@ type MigrationConfig struct {
Features []string `yaml:"features"`
}
type HttpSettingsConfig struct {
User string `yaml:"user"`
Password string `yaml:"password"`
ReadTimeoutSeconds uint64 `yaml:"read_timeout"`
}
type Config struct {
RootDir string
LogFile string `yaml:"log_file"`
LogFormat string `yaml:"log_format"`
Migration MigrationConfig `yaml:"migration"`
GitlabUrl string `yaml:"gitlab_url"`
GitlabTracing string `yaml:"gitlab_tracing"`
SecretFilePath string `yaml:"secret_file"`
Secret string `yaml:"secret"`
LogFile string `yaml:"log_file"`
LogFormat string `yaml:"log_format"`
Migration MigrationConfig `yaml:"migration"`
GitlabUrl string `yaml:"gitlab_url"`
GitlabTracing string `yaml:"gitlab_tracing"`
SecretFilePath string `yaml:"secret_file"`
Secret string `yaml:"secret"`
HttpSettings HttpSettingsConfig `yaml:"http_settings"`
HttpClient *HttpClient
}
func New() (*Config, error) {
......@@ -51,7 +59,7 @@ func (c *Config) FeatureEnabled(featureName string) bool {
return false
}
if !strings.HasPrefix(c.GitlabUrl, "http+unix://") {
if !strings.HasPrefix(c.GitlabUrl, "http+unix://") && !strings.HasPrefix(c.GitlabUrl, "http://") {
return false
}
......
......@@ -24,12 +24,13 @@ func TestParseConfig(t *testing.T) {
defer cleanup()
testCases := []struct {
yaml string
path string
format string
gitlabUrl string
migration MigrationConfig
secret string
yaml string
path string
format string
gitlabUrl string
migration MigrationConfig
secret string
httpSettings HttpSettingsConfig
}{
{
path: path.Join(testRoot, "gitlab-shell.log"),
......@@ -86,6 +87,13 @@ func TestParseConfig(t *testing.T) {
format: "text",
secret: "an inline secret",
},
{
yaml: "http_settings:\n user: user_basic_auth\n password: password_basic_auth\n read_timeout: 500",
path: path.Join(testRoot, "gitlab-shell.log"),
format: "text",
secret: "default-secret-content",
httpSettings: HttpSettingsConfig{User: "user_basic_auth", Password: "password_basic_auth", ReadTimeoutSeconds: 500},
},
}
for _, tc := range testCases {
......@@ -101,6 +109,7 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, tc.format, cfg.LogFormat)
assert.Equal(t, tc.gitlabUrl, cfg.GitlabUrl)
assert.Equal(t, tc.secret, cfg.Secret)
assert.Equal(t, tc.httpSettings, cfg.HttpSettings)
})
}
}
......@@ -139,6 +148,15 @@ func TestFeatureEnabled(t *testing.T) {
feature: "discover",
expectEnabled: false,
},
{
desc: "When the protocol is http and the feature enabled",
config: &Config{
GitlabUrl: "http://localhost:3000",
Migration: MigrationConfig{Enabled: true, Features: []string{"discover"}},
},
feature: "discover",
expectEnabled: true,
},
{
desc: "When the protocol is not supported",
config: &Config{
......
package config
import (
"context"
"net"
"net/http"
"strings"
"time"
)
const (
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
HttpProtocol = "http://"
defaultReadTimeoutSeconds = 300
)
type HttpClient struct {
HttpClient *http.Client
Host string
}
func (c *Config) GetHttpClient() *HttpClient {
if c.HttpClient != nil {
return c.HttpClient
}
var transport *http.Transport
var host string
if strings.HasPrefix(c.GitlabUrl, UnixSocketProtocol) {
transport, host = c.buildSocketTransport()
} else if strings.HasPrefix(c.GitlabUrl, HttpProtocol) {
transport, host = c.buildHttpTransport()
} else {
return nil
}
httpClient := &http.Client{
Transport: transport,
Timeout: c.readTimeout(),
}
client := &HttpClient{HttpClient: httpClient, Host: host}
c.HttpClient = client
return client
}
func (c *Config) buildSocketTransport() (*http.Transport, string) {
socketPath := strings.TrimPrefix(c.GitlabUrl, UnixSocketProtocol)
transport := &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", socketPath)
},
}
return transport, socketBaseUrl
}
func (c *Config) buildHttpTransport() (*http.Transport, string) {
return &http.Transport{}, c.GitlabUrl
}
func (c *Config) readTimeout() time.Duration {
timeoutSeconds := c.HttpSettings.ReadTimeoutSeconds
if timeoutSeconds == 0 {
timeoutSeconds = defaultReadTimeoutSeconds
}
return time.Duration(timeoutSeconds) * time.Second
}
package config
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadTimeout(t *testing.T) {
expectedSeconds := uint64(300)
config := &Config{
GitlabUrl: "http://localhost:3000",
HttpSettings: HttpSettingsConfig{ReadTimeoutSeconds: expectedSeconds},
}
client := config.GetHttpClient()
require.NotNil(t, client)
assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.HttpClient.Timeout)
}
package gitlabnet
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
......@@ -15,22 +17,24 @@ const (
secretHeaderName = "Gitlab-Shared-Secret"
)
type GitlabClient interface {
Get(path string) (*http.Response, error)
Post(path string, data interface{}) (*http.Response, error)
}
type ErrorResponse struct {
Message string `json:"message"`
}
func GetClient(config *config.Config) (GitlabClient, error) {
url := config.GitlabUrl
if strings.HasPrefix(url, UnixSocketProtocol) {
return buildSocketClient(config), nil
type GitlabClient struct {
httpClient *http.Client
config *config.Config
host string
}
func GetClient(config *config.Config) (*GitlabClient, error) {
client := config.GetHttpClient()
if client == nil {
return nil, fmt.Errorf("Unsupported protocol")
}
return nil, fmt.Errorf("Unsupported protocol")
return &GitlabClient{httpClient: client.HttpClient, config: config, host: client.Host}, nil
}
func normalizePath(path string) string {
......@@ -44,6 +48,27 @@ func normalizePath(path string) string {
return path
}
func newRequest(method, host, path string, data interface{}) (*http.Request, error) {
path = normalizePath(path)
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
jsonReader = bytes.NewReader(jsonData)
}
request, err := http.NewRequest(method, host+path, jsonReader)
if err != nil {
return nil, err
}
return request, nil
}
func parseError(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
......@@ -59,11 +84,32 @@ func parseError(resp *http.Response) error {
}
func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
func (c *GitlabClient) Get(path string) (*http.Response, error) {
return c.doRequest("GET", path, nil)
}
func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) {
return c.doRequest("POST", path, data)
}
func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.Response, error) {
request, err := newRequest(method, c.host, path, data)
if err != nil {
return nil, err
}
user, password := c.config.HttpSettings.User, c.config.HttpSettings.Password
if user != "" && password != "" {
request.SetBasicAuth(user, password)
}
encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.config.Secret))
request.Header.Set(secretHeaderName, encodedSecret)
response, err := client.Do(request)
request.Header.Add("Content-Type", "application/json")
request.Close = true
response, err := c.httpClient.Do(request)
if err != nil {
return nil, fmt.Errorf("Internal API unreachable")
}
......
......@@ -61,37 +61,44 @@ func TestClients(t *testing.T) {
},
},
}
testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
testCases := []struct {
desc string
client GitlabClient
server func([]testserver.TestRequestHandler) (func(), error)
secret string
server func([]testserver.TestRequestHandler) (func(), string, error)
}{
{
desc: "Socket client",
client: buildSocketClient(testConfig),
secret: "sssh, it's a secret",
server: testserver.StartSocketHttpServer,
},
{
desc: "Http client",
secret: "sssh, it's a secret",
server: testserver.StartHttpServer,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cleanup, err := tc.server(requests)
cleanup, url, err := tc.server(requests)
defer cleanup()
require.NoError(t, err)
testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client)
testSuccessfulPost(t, tc.client)
testMissing(t, tc.client)
testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client)
client, err := GetClient(&config.Config{GitlabUrl: url, Secret: tc.secret})
require.NoError(t, err)
testBrokenRequest(t, client)
testSuccessfulGet(t, client)
testSuccessfulPost(t, client)
testMissing(t, client)
testErrorMessage(t, client)
testAuthenticationHeader(t, client)
})
}
}
func testSuccessfulGet(t *testing.T, client GitlabClient) {
func testSuccessfulGet(t *testing.T, client *GitlabClient) {
t.Run("Successful get", func(t *testing.T) {
response, err := client.Get("/hello")
defer response.Body.Close()
......@@ -105,7 +112,7 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
})
}
func testSuccessfulPost(t *testing.T, client GitlabClient) {
func testSuccessfulPost(t *testing.T, client *GitlabClient) {
t.Run("Successful Post", func(t *testing.T) {
data := map[string]string{"key": "value"}
......@@ -121,7 +128,7 @@ func testSuccessfulPost(t *testing.T, client GitlabClient) {
})
}
func testMissing(t *testing.T, client GitlabClient) {
func testMissing(t *testing.T, client *GitlabClient) {
t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
......@@ -135,7 +142,7 @@ func testMissing(t *testing.T, client GitlabClient) {
})
}
func testErrorMessage(t *testing.T, client GitlabClient) {
func testErrorMessage(t *testing.T, client *GitlabClient) {
t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
......@@ -149,7 +156,7 @@ func testErrorMessage(t *testing.T, client GitlabClient) {
})
}
func testBrokenRequest(t *testing.T, client GitlabClient) {
func testBrokenRequest(t *testing.T, client *GitlabClient) {
t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
......@@ -163,7 +170,7 @@ func testBrokenRequest(t *testing.T, client GitlabClient) {
})
}
func testAuthenticationHeader(t *testing.T, client GitlabClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
......
......@@ -13,7 +13,7 @@ import (
type Client struct {
config *config.Config
client gitlabnet.GitlabClient
client *gitlabnet.GitlabClient
}
type Response struct {
......
......@@ -15,12 +15,10 @@ import (
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
requests []testserver.TestRequestHandler
)
func init() {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
......@@ -121,10 +119,10 @@ func TestErrorResponses(t *testing.T) {
}
func setup(t *testing.T) (*Client, func()) {
cleanup, err := testserver.StartSocketHttpServer(requests)
cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
client, err := NewClient(testConfig)
client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err)
return client, cleanup
......
package gitlabnet
import (
"encoding/base64"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
const (
username = "basic_auth_user"
password = "basic_auth_password"
)
func TestBasicAuthSettings(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/get_endpoint",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
fmt.Fprint(w, r.Header.Get("Authorization"))
},
},
{
Path: "/api/v4/internal/post_endpoint",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
fmt.Fprint(w, r.Header.Get("Authorization"))
},
},
}
config := &config.Config{HttpSettings: config.HttpSettingsConfig{User: username, Password: password}}
client, cleanup := setup(t, config, requests)
defer cleanup()
response, err := client.Get("/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
response, err = client.Post("/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
func testBasicAuthHeaders(t *testing.T, response *http.Response) {
defer response.Body.Close()
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
assert.NoError(t, err)
headerParts := strings.Split(string(responseBody), " ")
assert.Equal(t, "Basic", headerParts[0])
credentials, err := base64.StdEncoding.DecodeString(headerParts[1])
require.NoError(t, err)
assert.Equal(t, username+":"+password, string(credentials))
}
func TestEmptyBasicAuthSettings(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/empty_basic_auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "", r.Header.Get("Authorization"))
},
},
}
client, cleanup := setup(t, &config.Config{}, requests)
defer cleanup()
_, err := client.Get("/empty_basic_auth")
require.NoError(t, err)
}
func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) {
cleanup, url, err := testserver.StartHttpServer(requests)
require.NoError(t, err)
config.GitlabUrl = url
client, err := GetClient(config)
require.NoError(t, err)
return client, cleanup
}
package gitlabnet
import (
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
// We need to set the base URL to something starting with HTTP, the host
// itself is ignored as we're talking over a socket.
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
)
type GitlabSocketClient struct {
httpClient *http.Client
config *config.Config
}
func buildSocketClient(config *config.Config) *GitlabSocketClient {
path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", path)
},
},
}
return &GitlabSocketClient{httpClient: httpClient, config: config}
}
func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
path = normalizePath(path)
request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
path = normalizePath(path)
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
request.Header.Add("Content-Type", "application/json")
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
......@@ -5,6 +5,7 @@ import (
"log"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"path/filepath"
......@@ -12,7 +13,7 @@ import (
var (
tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
TestSocket = path.Join(tempDir, "internal.sock")
testSocket = path.Join(tempDir, "internal.sock")
)
type TestRequestHandler struct {
......@@ -20,14 +21,14 @@ type TestRequestHandler struct {
Handler func(w http.ResponseWriter, r *http.Request)
}
func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
return nil, err
func StartSocketHttpServer(handlers []TestRequestHandler) (func(), string, error) {
if err := os.MkdirAll(filepath.Dir(testSocket), 0700); err != nil {
return nil, "", err
}
socketListener, err := net.Listen("unix", TestSocket)
socketListener, err := net.Listen("unix", testSocket)
if err != nil {
return nil, err
return nil, "", err
}
server := http.Server{
......@@ -38,7 +39,15 @@ func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
}
go server.Serve(socketListener)
return cleanupSocket, nil
url := "http+unix://" + testSocket
return cleanupSocket, url, nil
}
func StartHttpServer(handlers []TestRequestHandler) (func(), string, error) {
server := httptest.NewServer(buildHandler(handlers))
return server.Close, server.URL, nil
}
func cleanupSocket() {
......
......@@ -15,7 +15,7 @@ import (
type Client struct {
config *config.Config
client gitlabnet.GitlabClient
client *gitlabnet.GitlabClient
}
type Response struct {
......
......@@ -17,12 +17,10 @@ import (
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
requests []testserver.TestRequestHandler
)
func initialize(t *testing.T) {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/two_factor_recovery_codes",
......@@ -151,10 +149,10 @@ func TestErrorResponses(t *testing.T) {
func setup(t *testing.T) (*Client, func()) {
initialize(t)
cleanup, err := testserver.StartSocketHttpServer(requests)
cleanup, url, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
client, err := NewClient(testConfig)
client, err := NewClient(&config.Config{GitlabUrl: url})
require.NoError(t, err)
return client, cleanup
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment