Commit f237aba6 authored by Nick Thomas's avatar Nick Thomas

Merge branch 'bvl-discover-command' into 'master'

Call gitlab "/internal/discover" from go

Closes #175

See merge request gitlab-org/gitlab-shell!283
parents 049beb74 83c0f18e
......@@ -7,25 +7,28 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
var (
binDir string
rootDir string
binDir string
rootDir string
reporter *reporting.Reporter
)
func init() {
binDir = filepath.Dir(os.Args[0])
rootDir = filepath.Dir(binDir)
reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr}
}
// rubyExec will never return. It either replaces the current process with a
// Ruby interpreter, or outputs an error and kills the process.
func execRuby() {
cmd := &fallback.Command{}
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Failed to exec: %v\n", err)
if err := cmd.Execute(reporter); err != nil {
fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err)
os.Exit(1)
}
}
......@@ -35,7 +38,7 @@ func main() {
// warning as this isn't something we can sustain indefinitely
config, err := config.NewFromDir(rootDir)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to read config, falling back to gitlab-shell-ruby")
fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby")
execRuby()
}
......@@ -43,14 +46,14 @@ func main() {
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
fmt.Fprintf(os.Stderr, "%v\n", err)
fmt.Fprintf(reporter.ErrOut, "%v\n", err)
os.Exit(1)
}
// The command will write to STDOUT on execution or replace the current
// process in case of the `fallback.Command`
if err = cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
if err = cmd.Execute(reporter); err != nil {
fmt.Fprintf(reporter.ErrOut, "%v\n", err)
os.Exit(1)
}
}
......@@ -4,11 +4,12 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
type Command interface {
Execute() error
Execute(*reporting.Reporter) error
}
func New(arguments []string, config *config.Config) (Command, error) {
......
......@@ -4,7 +4,9 @@ import (
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
type Command struct {
......@@ -12,6 +14,34 @@ type Command struct {
Args *commandargs.CommandArgs
}
func (c *Command) Execute() error {
return fmt.Errorf("No feature is implemented yet")
func (c *Command) Execute(reporter *reporting.Reporter) error {
response, err := c.getUserInfo()
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
if response.IsAnonymous() {
fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n")
} else {
fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username)
}
return nil
}
func (c *Command) getUserInfo() (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
if c.Args.GitlabKeyId != "" {
return client.GetByKeyId(c.Args.GitlabKeyId)
} else if c.Args.GitlabUsername != "" {
return client.GetByUsername(c.Args.GitlabUsername)
} else {
// There was no 'who' information, this matches the ruby error
// message.
return nil, fmt.Errorf("who='' is invalid")
}
}
package discover
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
var (
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" {
body := map[string]interface{}{
"id": 2,
"username": "alex-doe",
"name": "Alex Doe",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken_message" {
body := map[string]string{
"message": "Forbidden!",
}
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken" {
w.WriteHeader(http.StatusInternalServerError)
} else {
fmt.Fprint(w, "null")
}
},
},
}
)
func TestExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
testCases := []struct {
desc string
arguments *commandargs.CommandArgs
expectedOutput string
}{
{
desc: "With a known username",
arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"},
expectedOutput: "Welcome to GitLab, @alex-doe!\n",
},
{
desc: "With a known key id",
arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
expectedOutput: "Welcome to GitLab, @alex-doe!\n",
},
{
desc: "With an unknown key",
arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
expectedOutput: "Welcome to GitLab, Anonymous!\n",
},
{
desc: "With an unknown username",
arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"},
expectedOutput: "Welcome to GitLab, Anonymous!\n",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&reporting.Reporter{Out: buffer})
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String())
})
}
}
func TestFailingExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
testCases := []struct {
desc string
arguments *commandargs.CommandArgs
expectedError string
}{
{
desc: "With missing arguments",
arguments: &commandargs.CommandArgs{},
expectedError: "Failed to get username: who='' is invalid",
},
{
desc: "When the API returns an error",
arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"},
expectedError: "Failed to get username: Forbidden!",
},
{
desc: "When the API fails",
arguments: &commandargs.CommandArgs{GitlabUsername: "broken"},
expectedError: "Failed to get username: Internal API error (500)",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
err := cmd.Execute(&reporting.Reporter{Out: buffer})
assert.Empty(t, buffer.String())
assert.EqualError(t, err, tc.expectedError)
})
}
}
......@@ -4,6 +4,8 @@ import (
"os"
"path/filepath"
"syscall"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
)
type Command struct{}
......@@ -12,7 +14,7 @@ var (
binDir = filepath.Dir(os.Args[0])
)
func (c *Command) Execute() error {
func (c *Command) Execute(_ *reporting.Reporter) error {
rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby")
execErr := syscall.Exec(rubyCmd, os.Args, os.Environ())
return execErr
......
package reporting
import "io"
type Reporter struct {
Out io.Writer
ErrOut io.Writer
}
package gitlabnet
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
internalApiPath = "/api/v4/internal"
secretHeaderName = "Gitlab-Shared-Secret"
)
type GitlabClient interface {
Get(path string) (*http.Response, error)
// TODO: implement posts
// Post(path string) (http.Response, error)
}
type ErrorResponse struct {
Message string `json:"message"`
}
func GetClient(config *config.Config) (GitlabClient, error) {
url := config.GitlabUrl
if strings.HasPrefix(url, UnixSocketProtocol) {
return buildSocketClient(config), nil
}
return nil, fmt.Errorf("Unsupported protocol")
}
func normalizePath(path string) string {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
if !strings.HasPrefix(path, internalApiPath) {
path = internalApiPath + path
}
return path
}
func parseError(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
}
defer resp.Body.Close()
parsedResponse := &ErrorResponse{}
if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
} else {
return fmt.Errorf(parsedResponse.Message)
}
}
func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
request.Header.Set(secretHeaderName, encodedSecret)
response, err := client.Do(request)
if err != nil {
return nil, fmt.Errorf("Internal API unreachable")
}
if err := parseError(response); err != nil {
return nil, err
}
return response, nil
}
package gitlabnet
import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
func TestClients(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/hello",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello")
},
},
{
Path: "/api/v4/internal/auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, r.Header.Get(secretHeaderName))
},
},
{
Path: "/api/v4/internal/error",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
body := map[string]string{
"message": "Don't do that",
}
json.NewEncoder(w).Encode(body)
},
},
{
Path: "/api/v4/internal/broken",
Handler: func(w http.ResponseWriter, r *http.Request) {
panic("Broken")
},
},
}
testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
testCases := []struct {
desc string
client GitlabClient
server func([]testserver.TestRequestHandler) (func(), error)
}{
{
desc: "Socket client",
client: buildSocketClient(testConfig),
server: testserver.StartSocketHttpServer,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cleanup, err := tc.server(requests)
defer cleanup()
require.NoError(t, err)
testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client)
testMissing(t, tc.client)
testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client)
})
}
}
func testSuccessfulGet(t *testing.T, client GitlabClient) {
t.Run("Successful get", func(t *testing.T) {
response, err := client.Get("/hello")
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, string(responseBody), "Hello")
})
}
func testMissing(t *testing.T, client GitlabClient) {
t.Run("Missing error", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
})
}
func testErrorMessage(t *testing.T, client GitlabClient) {
t.Run("Error with message", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
}
func testBrokenRequest(t *testing.T, client GitlabClient) {
t.Run("Broken request", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
})
}
func testAuthenticationHeader(t *testing.T, client GitlabClient) {
t.Run("Authentication headers", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
require.NoError(t, err)
header, err := base64.StdEncoding.DecodeString(string(responseBody))
require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header))
})
}
package discover
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
)
type Client struct {
config *config.Config
client gitlabnet.GitlabClient
}
type Response struct {
UserId int64 `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
}
func NewClient(config *config.Config) (*Client, error) {
client, err := gitlabnet.GetClient(config)
if err != nil {
return nil, fmt.Errorf("Error creating http client: %v", err)
}
return &Client{config: config, client: client}, nil
}
func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{}
params.Add("key_id", keyId)
return c.getResponse(params)
}
func (c *Client) GetByUsername(username string) (*Response, error) {
params := url.Values{}
params.Add("username", username)
return c.getResponse(params)
}
func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
parsedResponse := &Response{}
if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
return nil, err
} else {
return parsedResponse, nil
}
}
func (c *Client) getResponse(params url.Values) (*Response, error) {
path := "/discover?" + params.Encode()
response, err := c.client.Get(path)
if err != nil {
return nil, err
}
defer response.Body.Close()
parsedResponse, err := c.parseResponse(response)
if err != nil {
return nil, fmt.Errorf("Parsing failed")
}
return parsedResponse, nil
}
func (r *Response) IsAnonymous() bool {
return r.UserId < 1
}
package discover
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
)
func init() {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("key_id") == "1" {
body := &Response{
UserId: 2,
Username: "alex-doe",
Name: "Alex Doe",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "jane-doe" {
body := &Response{
UserId: 1,
Username: "jane-doe",
Name: "Jane Doe",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken_message" {
w.WriteHeader(http.StatusForbidden)
body := &gitlabnet.ErrorResponse{
Message: "Not allowed!",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken_json" {
w.Write([]byte("{ \"message\": \"broken json!\""))
} else if r.URL.Query().Get("username") == "broken_empty" {
w.WriteHeader(http.StatusForbidden)
} else {
fmt.Fprint(w, "null")
}
},
},
}
}
func TestGetByKeyId(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByKeyId("1")
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
}
func TestGetByUsername(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByUsername("jane-doe")
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
}
func TestMissingUser(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByUsername("missing")
assert.NoError(t, err)
assert.True(t, result.IsAnonymous())
}
func TestErrorResponses(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
testCases := []struct {
desc string
fakeUsername string
expectedError string
}{
{
desc: "A response with an error message",
fakeUsername: "broken_message",
expectedError: "Not allowed!",
},
{
desc: "A response with bad JSON",
fakeUsername: "broken_json",
expectedError: "Parsing failed",
},
{
desc: "An error response without message",
fakeUsername: "broken_empty",
expectedError: "Internal API error (403)",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
resp, err := client.GetByUsername(tc.fakeUsername)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)
})
}
}
func setup(t *testing.T) (*Client, func()) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
client, err := NewClient(testConfig)
require.NoError(t, err)
return client, cleanup
}
package gitlabnet
import (
"context"
"net"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
// We need to set the base URL to something starting with HTTP, the host
// itself is ignored as we're talking over a socket.
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
)
type GitlabSocketClient struct {
httpClient *http.Client
config *config.Config
}
func buildSocketClient(config *config.Config) *GitlabSocketClient {
path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", path)
},
},
}
return &GitlabSocketClient{httpClient: httpClient, config: config}
}
func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
path = normalizePath(path)
request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
package testserver
import (
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path"
"path/filepath"
)
var (
tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
TestSocket = path.Join(tempDir, "internal.sock")
)
type TestRequestHandler struct {
Path string
Handler func(w http.ResponseWriter, r *http.Request)
}
func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
return nil, err
}
socketListener, err := net.Listen("unix", TestSocket)
if err != nil {
return nil, err
}
server := http.Server{
Handler: buildHandler(handlers),
// We'll put this server through some nasty stuff we don't want
// in our test output
ErrorLog: log.New(ioutil.Discard, "", 0),
}
go server.Serve(socketListener)
return cleanupSocket, nil
}
func cleanupSocket() {
os.RemoveAll(tempDir)
}
func buildHandler(handlers []TestRequestHandler) http.Handler {
h := http.NewServeMux()
for _, handler := range handlers {
h.HandleFunc(handler.Path, handler.Handler)
}
return h
}
......@@ -30,12 +30,19 @@ describe 'bin/gitlab-shell' do
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
@server.mount_proc('/api/v4/internal/discover') do |req, res|
if req.query['key_id'] == '100' ||
req.query['user_id'] == '10' ||
req.query['username'] == 'someuser'
identifier = req.query['key_id'] || req.query['username'] || req.query['user_id']
known_identifiers = %w(10 someuser 100)
if known_identifiers.include?(identifier)
res.status = 200
res.content_type = 'application/json'
res.body = '{"id":1, "name": "Some User", "username": "someuser"}'
elsif identifier == 'broken_message'
res.status = 401
res.body = '{"message": "Forbidden!"}'
elsif identifier && identifier != 'broken'
res.status = 200
res.content_type = 'application/json'
res.body = 'null'
else
res.status = 500
end
......@@ -145,11 +152,7 @@ describe 'bin/gitlab-shell' do
)
end
it_behaves_like 'results with keys' do
before do
pending
end
end
it_behaves_like 'results with keys'
it 'outputs "Only ssh allowed"' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-someuser"], env: {})
......@@ -157,6 +160,20 @@ describe 'bin/gitlab-shell' do
expect(stderr).to eq("Only ssh allowed\n")
expect(status).not_to be_success
end
it 'returns an error message when the API call fails with a message' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken_message"])
expect(stderr).to match(/Failed to get username: Forbidden!/)
expect(status).not_to be_success
end
it 'returns an error message when the API call fails without a message' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken"])
expect(stderr).to match(/Failed to get username: Internal API error \(500\)/)
expect(status).not_to be_success
end
end
def run!(args, env: {'SSH_CONNECTION' => 'fake'})
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment