Commit 53511f36 authored by Bob Van Landuyt's avatar Bob Van Landuyt

Detect user based on key, username or id

This allows gitlab-shell to be called with an argument of the format
`key-123` or `username-name`.

When called in this way, `gitlab-shell` will call the GitLab internal
API. If the API responds with user information, it will print a
welcome message including the username.

If the API responds with a successful but empty response, gitlab-shell
will print a welcome message for an anonymous user.

If the API response includes an error message in JSON, this message
will be printed to stderr.

If the API call fails, an error message including the status code will
be printed to stderr.
parent 049beb74
......@@ -2,9 +2,12 @@ package discover
import (
"fmt"
"io"
"os"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
type Command struct {
......@@ -12,6 +15,38 @@ type Command struct {
Args *commandargs.CommandArgs
}
var (
output io.Writer = os.Stdout
)
func (c *Command) Execute() error {
return fmt.Errorf("No feature is implemented yet")
response, err := c.getUserInfo()
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
if response.IsAnonymous() {
fmt.Fprintf(output, "Welcome to GitLab, Anonymous!\n")
} else {
fmt.Fprintf(output, "Welcome to GitLab, @%s!\n", response.Username)
}
return nil
}
func (c *Command) getUserInfo() (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
if c.Args.GitlabKeyId != "" {
return client.GetByKeyId(c.Args.GitlabKeyId)
} else if c.Args.GitlabUsername != "" {
return client.GetByUsername(c.Args.GitlabUsername)
} else {
// There was no 'who' information, this matches the ruby error
// message.
return nil, fmt.Errorf("who='' is invalid")
}
}
package discover
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
var (
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" {
body := map[string]interface{}{
"id": 2,
"username": "alex-doe",
"name": "Alex Doe",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken_message" {
body := map[string]string{
"message": "Forbidden!",
}
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "broken" {
w.WriteHeader(http.StatusInternalServerError)
} else {
fmt.Fprint(w, "null")
}
},
},
}
)
func TestExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
testCases := []struct {
desc string
arguments *commandargs.CommandArgs
expectedOutput string
}{
{
desc: "With a known username",
arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"},
expectedOutput: "Welcome to GitLab, @alex-doe!\n",
},
{
desc: "With a known key id",
arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
expectedOutput: "Welcome to GitLab, @alex-doe!\n",
},
{
desc: "With an unknown key",
arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
expectedOutput: "Welcome to GitLab, Anonymous!\n",
},
{
desc: "With an unknown username",
arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"},
expectedOutput: "Welcome to GitLab, Anonymous!\n",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
buffer := &bytes.Buffer{}
output = buffer
cmd := &Command{Config: testConfig, Args: tc.arguments}
err := cmd.Execute()
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String())
})
}
}
func TestFailingExecute(t *testing.T) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
defer cleanup()
testCases := []struct {
desc string
arguments *commandargs.CommandArgs
expectedError string
}{
{
desc: "With missing arguments",
arguments: &commandargs.CommandArgs{},
expectedError: "Failed to get username: who='' is invalid",
},
{
desc: "When the API returns an error",
arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"},
expectedError: "Failed to get username: Forbidden!",
},
{
desc: "When the API fails",
arguments: &commandargs.CommandArgs{GitlabUsername: "broken"},
expectedError: "Failed to get username: Internal API error (500)",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
err := cmd.Execute()
assert.EqualError(t, err, tc.expectedError)
})
}
}
package gitlabnet
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
internalApiPath = "/api/v4/internal"
secretHeaderName = "Gitlab-Shared-Secret"
)
type GitlabClient interface {
Get(path string) (*http.Response, error)
// TODO: implement posts
// Post(path string) (http.Response, error)
}
type ErrorResponse struct {
Message string `json:"message"`
}
func GetClient(config *config.Config) (GitlabClient, error) {
url := config.GitlabUrl
if strings.HasPrefix(url, UnixSocketProtocol) {
return buildSocketClient(config), nil
}
return nil, fmt.Errorf("Unsupported protocol")
}
func normalizePath(path string) string {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
if !strings.HasPrefix(path, internalApiPath) {
path = internalApiPath + path
}
return path
}
func parseError(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
}
defer resp.Body.Close()
parsedResponse := &ErrorResponse{}
if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
} else {
return fmt.Errorf(parsedResponse.Message)
}
}
func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
request.Header.Set(secretHeaderName, encodedSecret)
response, err := client.Do(request)
if err != nil {
return nil, fmt.Errorf("Internal API unreachable")
}
if err := parseError(response); err != nil {
return nil, err
}
return response, nil
}
package gitlabnet
import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
func TestClients(t *testing.T) {
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/hello",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello")
},
},
{
Path: "/api/v4/internal/auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, r.Header.Get(secretHeaderName))
},
},
{
Path: "/api/v4/internal/error",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
body := map[string]string{
"message": "Don't do that",
}
json.NewEncoder(w).Encode(body)
},
},
{
Path: "/api/v4/internal/broken",
Handler: func(w http.ResponseWriter, r *http.Request) {
panic("Broken")
},
},
}
testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
testCases := []struct {
desc string
client GitlabClient
server func([]testserver.TestRequestHandler) (func(), error)
}{
{
desc: "Socket client",
client: buildSocketClient(testConfig),
server: testserver.StartSocketHttpServer,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
cleanup, err := tc.server(requests)
defer cleanup()
require.NoError(t, err)
testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client)
testMissing(t, tc.client)
testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client)
})
}
}
func testSuccessfulGet(t *testing.T, client GitlabClient) {
t.Run("Successful get", func(t *testing.T) {
response, err := client.Get("/hello")
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, string(responseBody), "Hello")
})
}
func testMissing(t *testing.T, client GitlabClient) {
t.Run("Missing error", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
})
}
func testErrorMessage(t *testing.T, client GitlabClient) {
t.Run("Error with message", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
}
func testBrokenRequest(t *testing.T, client GitlabClient) {
t.Run("Broken request", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
})
}
func testAuthenticationHeader(t *testing.T, client GitlabClient) {
t.Run("Authentication headers", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
require.NoError(t, err)
require.NotNil(t, response)
responseBody, err := ioutil.ReadAll(response.Body)
require.NoError(t, err)
header, err := base64.StdEncoding.DecodeString(string(responseBody))
require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header))
})
}
package discover
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
)
type Client struct {
config *config.Config
client gitlabnet.GitlabClient
}
type Response struct {
UserId int64 `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
}
func NewClient(config *config.Config) (*Client, error) {
client, err := gitlabnet.GetClient(config)
if err != nil {
return nil, fmt.Errorf("Error creating http client: %v", err)
}
return &Client{config: config, client: client}, nil
}
func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{}
params.Add("key_id", keyId)
return c.getResponse(params)
}
func (c *Client) GetByUsername(username string) (*Response, error) {
params := url.Values{}
params.Add("username", username)
return c.getResponse(params)
}
func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
defer resp.Body.Close()
parsedResponse := &Response{}
if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
return nil, err
} else {
return parsedResponse, nil
}
}
func (c *Client) getResponse(params url.Values) (*Response, error) {
path := "/discover?" + params.Encode()
response, err := c.client.Get(path)
if err != nil {
return nil, err
}
return c.parseResponse(response)
}
func (r *Response) IsAnonymous() bool {
return r.UserId < 1
}
package discover
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
testConfig *config.Config
requests []testserver.TestRequestHandler
)
func init() {
testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("key_id") == "1" {
body := map[string]interface{}{
"id": 2,
"username": "alex-doe",
"name": "Alex Doe",
}
json.NewEncoder(w).Encode(body)
} else if r.URL.Query().Get("username") == "jane-doe" {
body := map[string]interface{}{
"id": 1,
"username": "jane-doe",
"name": "Jane Doe",
}
json.NewEncoder(w).Encode(body)
} else {
fmt.Fprint(w, "null")
}
},
},
}
}
func TestGetByKeyId(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByKeyId("1")
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
}
func TestGetByUsername(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByUsername("jane-doe")
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
}
func TestMissingUser(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
result, err := client.GetByUsername("missing")
assert.NoError(t, err)
assert.True(t, result.IsAnonymous())
}
func setup(t *testing.T) (*Client, func()) {
cleanup, err := testserver.StartSocketHttpServer(requests)
require.NoError(t, err)
client, err := NewClient(testConfig)
require.NoError(t, err)
return client, cleanup
}
package gitlabnet
import (
"context"
"net"
"net/http"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
const (
// We need to set the base URL to something starting with HTTP, the host
// itself is ignored as we're talking over a socket.
socketBaseUrl = "http://unix"
UnixSocketProtocol = "http+unix://"
)
type GitlabSocketClient struct {
httpClient *http.Client
config *config.Config
}
func buildSocketClient(config *config.Config) *GitlabSocketClient {
path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", path)
},
},
}
return &GitlabSocketClient{httpClient: httpClient, config: config}
}
func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
path = normalizePath(path)
request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
if err != nil {
return nil, err
}
return doRequest(c.httpClient, c.config, request)
}
package testserver
import (
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path"
"path/filepath"
)
var (
tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
TestSocket = path.Join(tempDir, "internal.sock")
)
type TestRequestHandler struct {
Path string
Handler func(w http.ResponseWriter, r *http.Request)
}
func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
return nil, err
}
socketListener, err := net.Listen("unix", TestSocket)
if err != nil {
return nil, err
}
server := http.Server{
Handler: buildHandler(handlers),
// We'll put this server through some nasty stuff we don't want
// in our test output
ErrorLog: log.New(ioutil.Discard, "", 0),
}
go server.Serve(socketListener)
return cleanupSocket, nil
}
func cleanupSocket() {
os.RemoveAll(tempDir)
}
func buildHandler(handlers []TestRequestHandler) http.Handler {
h := http.NewServeMux()
for _, handler := range handlers {
h.HandleFunc(handler.Path, handler.Handler)
}
return h
}
......@@ -30,12 +30,19 @@ describe 'bin/gitlab-shell' do
@server = HTTPUNIXServer.new(BindAddress: tmp_socket_path)
@server.mount_proc('/api/v4/internal/discover') do |req, res|
if req.query['key_id'] == '100' ||
req.query['user_id'] == '10' ||
req.query['username'] == 'someuser'
identifier = req.query['key_id'] || req.query['username'] || req.query['user_id']
known_identifiers = %w(10 someuser 100)
if known_identifiers.include?(identifier)
res.status = 200
res.content_type = 'application/json'
res.body = '{"id":1, "name": "Some User", "username": "someuser"}'
elsif identifier == 'broken_message'
res.status = 401
res.body = '{"message": "Forbidden!"}'
elsif identifier && identifier != 'broken'
res.status = 200
res.content_type = 'application/json'
res.body = 'null'
else
res.status = 500
end
......@@ -145,11 +152,7 @@ describe 'bin/gitlab-shell' do
)
end
it_behaves_like 'results with keys' do
before do
pending
end
end
it_behaves_like 'results with keys'
it 'outputs "Only ssh allowed"' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-someuser"], env: {})
......@@ -157,6 +160,20 @@ describe 'bin/gitlab-shell' do
expect(stderr).to eq("Only ssh allowed\n")
expect(status).not_to be_success
end
it 'returns an error message when the API call fails with a message' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken_message"])
expect(stderr).to match(/Failed to get username: Forbidden!/)
expect(status).not_to be_success
end
it 'returns an error message when the API call fails without a message' do
_, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken"])
expect(stderr).to match(/Failed to get username: Internal API error \(500\)/)
expect(status).not_to be_success
end
end
def run!(args, env: {'SSH_CONNECTION' => 'fake'})
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment