Commit 8662516a authored by Jacob Vosmaer's avatar Jacob Vosmaer Committed by Nick Thomas

Simplify config handling in main()

parent ecec37a2
#!/bin/sh #!/bin/sh
set -eu
IMPORT_RESULT=$(goimports -e -local "gitlab.com/gitlab-org/gitlab-workhorse" -l "$@") IMPORT_RESULT=$(goimports -e -local "gitlab.com/gitlab-org/gitlab-workhorse" -l "$@")
if [ -n "${IMPORT_RESULT}" ]; then if [ -n "${IMPORT_RESULT}" ]; then
......
---
title: Simplify config handling in main()
merge_request: 634
author:
type: other
package main
import (
"flag"
"io"
"io/ioutil"
"net/url"
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
)
func TestConfigFile(t *testing.T) {
f, err := ioutil.TempFile("", "workhorse-config-test")
require.NoError(t, err)
defer os.Remove(f.Name())
data := `
[redis]
password = "redis password"
[object_storage]
provider = "test provider"
[image_resizer]
max_scaler_procs = 123
`
_, err = io.WriteString(f, data)
require.NoError(t, err)
require.NoError(t, f.Close())
_, cfg, err := buildConfig("test", []string{"-config", f.Name()})
require.NoError(t, err, "build config")
// These are integration tests: we want to see that each section in the
// config file ends up in the config struct. We do not test all the
// fields in each section; that should happen in the tests of the
// internal/config package.
require.Equal(t, "redis password", cfg.Redis.Password)
require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider)
require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs")
}
func TestConfigErrorHelp(t *testing.T) {
for _, f := range []string{"-h", "-help"} {
t.Run(f, func(t *testing.T) {
_, _, err := buildConfig("test", []string{f})
require.Equal(t, alreadyPrintedError{flag.ErrHelp}, err)
})
}
}
func TestConfigError(t *testing.T) {
for _, arg := range []string{"-foobar", "foobar"} {
t.Run(arg, func(t *testing.T) {
_, _, err := buildConfig("test", []string{arg})
require.Error(t, err)
require.IsType(t, alreadyPrintedError{}, err)
})
}
}
func TestConfigDefaults(t *testing.T) {
boot, cfg, err := buildConfig("test", nil)
require.NoError(t, err, "build config")
expectedBoot := &bootConfig{
secretPath: "./.gitlab_workhorse_secret",
listenAddr: "localhost:8181",
listenNetwork: "tcp",
logFormat: "text",
}
require.Equal(t, expectedBoot, boot)
expectedCfg := &config.Config{
Backend: upstream.DefaultBackend,
CableBackend: upstream.DefaultBackend,
Version: "(unknown version)",
DocumentRoot: "public",
ProxyHeadersTimeout: 5 * time.Minute,
APIQueueTimeout: queueing.DefaultTimeout,
APICILongPollingDuration: 50 * time.Nanosecond, // TODO this is meant to be 50*time.Second but it has been wrong for ages
ImageResizerConfig: config.DefaultImageResizerConfig,
}
require.Equal(t, expectedCfg, cfg)
}
func TestConfigFlagParsing(t *testing.T) {
backendURL, err := url.Parse("http://localhost:1234")
require.NoError(t, err)
cableURL, err := url.Parse("http://localhost:5678")
require.NoError(t, err)
args := []string{
"-version",
"-secretPath", "secret path",
"-listenAddr", "listen addr",
"-listenNetwork", "listen network",
"-listenUmask", "123",
"-pprofListenAddr", "pprof listen addr",
"-prometheusListenAddr", "prometheus listen addr",
"-logFile", "log file",
"-logFormat", "log format",
"-documentRoot", "document root",
"-developmentMode",
"-authBackend", backendURL.String(),
"-authSocket", "auth socket",
"-cableBackend", cableURL.String(),
"-cableSocket", "cable socket",
"-proxyHeadersTimeout", "10m",
"-apiLimit", "234",
"-apiQueueLimit", "345",
"-apiQueueDuration", "123s",
"-apiCiLongPollingDuration", "234s",
"-propagateCorrelationID",
}
boot, cfg, err := buildConfig("test", args)
require.NoError(t, err, "build config")
expectedBoot := &bootConfig{
secretPath: "secret path",
listenAddr: "listen addr",
listenNetwork: "listen network",
listenUmask: 123,
pprofListenAddr: "pprof listen addr",
prometheusListenAddr: "prometheus listen addr",
logFile: "log file",
logFormat: "log format",
printVersion: true,
}
require.Equal(t, expectedBoot, boot)
expectedCfg := &config.Config{
DocumentRoot: "document root",
DevelopmentMode: true,
Backend: backendURL,
Socket: "auth socket",
CableBackend: cableURL,
CableSocket: "cable socket",
Version: "(unknown version)",
ProxyHeadersTimeout: 10 * time.Minute,
APILimit: 234,
APIQueueLimit: 345,
APIQueueTimeout: 123 * time.Second,
APICILongPollingDuration: 234 * time.Second,
PropagateCorrelationID: true,
ImageResizerConfig: config.DefaultImageResizerConfig,
}
require.Equal(t, expectedCfg, cfg)
}
...@@ -115,11 +115,10 @@ var DefaultImageResizerConfig = &ImageResizerConfig{ ...@@ -115,11 +115,10 @@ var DefaultImageResizerConfig = &ImageResizerConfig{
MaxFilesize: DefaultImageResizerMaxFilesize, MaxFilesize: DefaultImageResizerMaxFilesize,
} }
// LoadConfig from a file func LoadConfig(data string) (*Config, error) {
func LoadConfig(filename string) (*Config, error) {
cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig} cfg := &Config{ImageResizerConfig: DefaultImageResizerConfig}
if _, err := toml.DecodeFile(filename, cfg); err != nil { if _, err := toml.Decode(data, cfg); err != nil {
return nil, err return nil, err
} }
......
package config package config
import ( import (
"io/ioutil"
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -20,13 +18,20 @@ azure_storage_access_key = "deadbeef" ...@@ -20,13 +18,20 @@ azure_storage_access_key = "deadbeef"
func TestLoadEmptyConfig(t *testing.T) { func TestLoadEmptyConfig(t *testing.T) {
config := `` config := ``
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.Nil(t, cfg.ObjectStorageCredentials) expected := Config{
ImageResizerConfig: &ImageResizerConfig{
MaxScalerProcs: DefaultImageResizerMaxScalerProcs,
MaxFilesize: DefaultImageResizerMaxFilesize,
},
}
err := cfg.RegisterGoCloudURLOpeners() require.Equal(t, expected, *cfg)
require.NoError(t, err)
require.Nil(t, cfg.ObjectStorageCredentials)
require.NoError(t, cfg.RegisterGoCloudURLOpeners())
} }
func TestLoadObjectStorageConfig(t *testing.T) { func TestLoadObjectStorageConfig(t *testing.T) {
...@@ -39,8 +44,8 @@ aws_access_key_id = "minio" ...@@ -39,8 +44,8 @@ aws_access_key_id = "minio"
aws_secret_access_key = "gdk-minio" aws_secret_access_key = "gdk-minio"
` `
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
...@@ -56,8 +61,8 @@ aws_secret_access_key = "gdk-minio" ...@@ -56,8 +61,8 @@ aws_secret_access_key = "gdk-minio"
} }
func TestRegisterGoCloudURLOpeners(t *testing.T) { func TestRegisterGoCloudURLOpeners(t *testing.T) {
tmpFile, cfg := loadTempConfig(t, azureConfig) cfg, err := LoadConfig(azureConfig)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
...@@ -72,30 +77,13 @@ func TestRegisterGoCloudURLOpeners(t *testing.T) { ...@@ -72,30 +77,13 @@ func TestRegisterGoCloudURLOpeners(t *testing.T) {
require.Equal(t, expected, *cfg.ObjectStorageCredentials) require.Equal(t, expected, *cfg.ObjectStorageCredentials)
require.Nil(t, cfg.ObjectStorageConfig.URLMux) require.Nil(t, cfg.ObjectStorageConfig.URLMux)
err := cfg.RegisterGoCloudURLOpeners() require.NoError(t, cfg.RegisterGoCloudURLOpeners())
require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageConfig.URLMux) require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob")) require.True(t, cfg.ObjectStorageConfig.URLMux.ValidBucketScheme("azblob"))
require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes()) require.Equal(t, []string{"azblob"}, cfg.ObjectStorageConfig.URLMux.BucketSchemes())
} }
func TestLoadDefaultConfig(t *testing.T) {
config := ``
tmpFile, cfg := loadTempConfig(t, config)
defer os.Remove(tmpFile.Name())
expected := Config{
ImageResizerConfig: &ImageResizerConfig{
MaxScalerProcs: DefaultImageResizerMaxScalerProcs,
MaxFilesize: DefaultImageResizerMaxFilesize,
},
}
require.Equal(t, expected, *cfg)
}
func TestLoadImageResizerConfig(t *testing.T) { func TestLoadImageResizerConfig(t *testing.T) {
config := ` config := `
[image_resizer] [image_resizer]
...@@ -103,8 +91,8 @@ max_scaler_procs = 200 ...@@ -103,8 +91,8 @@ max_scaler_procs = 200
max_filesize = 350000 max_filesize = 350000
` `
tmpFile, cfg := loadTempConfig(t, config) cfg, err := LoadConfig(config)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config") require.NotNil(t, cfg.ImageResizerConfig, "Expected image resizer config")
...@@ -115,16 +103,3 @@ max_filesize = 350000 ...@@ -115,16 +103,3 @@ max_filesize = 350000
require.Equal(t, expected, *cfg.ImageResizerConfig) require.Equal(t, expected, *cfg.ImageResizerConfig)
} }
func loadTempConfig(t *testing.T, config string) (f *os.File, cfg *Config) {
tmpFile, err := ioutil.TempFile(os.TempDir(), "test-")
require.NoError(t, err)
_, err = tmpFile.Write([]byte(config))
require.NoError(t, err)
cfg, err = LoadConfig(tmpFile.Name())
require.NoError(t, err)
return tmpFile, cfg
}
...@@ -3,7 +3,6 @@ package config ...@@ -3,7 +3,6 @@ package config
import ( import (
"context" "context"
"net/url" "net/url"
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -11,13 +10,12 @@ import ( ...@@ -11,13 +10,12 @@ import (
) )
func TestURLOpeners(t *testing.T) { func TestURLOpeners(t *testing.T) {
tmpFile, cfg := loadTempConfig(t, azureConfig) cfg, err := LoadConfig(azureConfig)
defer os.Remove(tmpFile.Name()) require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials") require.NotNil(t, cfg.ObjectStorageCredentials, "Expected object storage credentials")
err := cfg.RegisterGoCloudURLOpeners() require.NoError(t, cfg.RegisterGoCloudURLOpeners())
require.NoError(t, err)
require.NotNil(t, cfg.ObjectStorageConfig.URLMux) require.NotNil(t, cfg.ObjectStorageConfig.URLMux)
tests := []struct { tests := []struct {
......
...@@ -18,26 +18,20 @@ const ( ...@@ -18,26 +18,20 @@ const (
noneLogType = "none" noneLogType = "none"
) )
type logConfiguration struct { func startLogging(file string, format string) (io.Closer, error) {
logFile string
logFormat string
}
func startLogging(config logConfiguration) (io.Closer, error) {
// Golog always goes to stderr // Golog always goes to stderr
goLog.SetOutput(os.Stderr) goLog.SetOutput(os.Stderr)
logFile := config.logFile if file == "" {
if logFile == "" { file = "stderr"
logFile = "stderr"
} }
switch config.logFormat { switch format {
case noneLogType: case noneLogType:
return logkit.Initialize(logkit.WithWriter(ioutil.Discard)) return logkit.Initialize(logkit.WithWriter(ioutil.Discard))
case jsonLogFormat: case jsonLogFormat:
return logkit.Initialize( return logkit.Initialize(
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
logkit.WithFormatter("json"), logkit.WithFormatter("json"),
) )
case textLogFormat: case textLogFormat:
...@@ -48,23 +42,22 @@ func startLogging(config logConfiguration) (io.Closer, error) { ...@@ -48,23 +42,22 @@ func startLogging(config logConfiguration) (io.Closer, error) {
) )
case structuredFormat: case structuredFormat:
return logkit.Initialize( return logkit.Initialize(
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
logkit.WithFormatter("color"), logkit.WithFormatter("color"),
) )
} }
return nil, fmt.Errorf("unknown logFormat: %v", config.logFormat) return nil, fmt.Errorf("unknown logFormat: %v", format)
} }
// In text format, we use a separate logger for access logs // In text format, we use a separate logger for access logs
func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) { func getAccessLogger(file string, format string) (*log.Logger, io.Closer, error) {
if config.logFormat != "text" { if format != "text" {
return log.StandardLogger(), ioutil.NopCloser(nil), nil return log.StandardLogger(), ioutil.NopCloser(nil), nil
} }
logFile := config.logFile if file == "" {
if logFile == "" { file = "stderr"
logFile = "stderr"
} }
accessLogger := log.New() accessLogger := log.New()
...@@ -72,7 +65,7 @@ func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) { ...@@ -72,7 +65,7 @@ func getAccessLogger(config logConfiguration) (*log.Logger, io.Closer, error) {
closer, err := logkit.Initialize( closer, err := logkit.Initialize(
logkit.WithLogger(accessLogger), // Configure `accessLogger` logkit.WithLogger(accessLogger), // Configure `accessLogger`
logkit.WithFormatter("combined"), // Use the combined formatter logkit.WithFormatter("combined"), // Use the combined formatter
logkit.WithOutputName(logFile), logkit.WithOutputName(file),
) )
return accessLogger, closer, err return accessLogger, closer, err
......
This diff is collapsed.
...@@ -38,6 +38,9 @@ import ( ...@@ -38,6 +38,9 @@ import (
const scratchDir = "testdata/scratch" const scratchDir = "testdata/scratch"
const testRepoRoot = "testdata/data" const testRepoRoot = "testdata/data"
const testDocumentRoot = "testdata/public" const testDocumentRoot = "testdata/public"
var absDocumentRoot string
const testRepo = "group/test.git" const testRepo = "group/test.git"
const testProject = "group/test" const testProject = "group/test"
...@@ -183,7 +186,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) { ...@@ -183,7 +186,7 @@ func TestAllowedPublicUploadsFile(t *testing.T) {
proxied := false proxied := false
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
proxied = true proxied = true
w.Header().Add("X-Sendfile", *documentRoot+r.URL.Path) w.Header().Add("X-Sendfile", absDocumentRoot+r.URL.Path)
w.WriteHeader(200) w.WriteHeader(200)
}) })
defer ts.Close() defer ts.Close()
...@@ -577,11 +580,11 @@ func setupStaticFile(fpath, content string) error { ...@@ -577,11 +580,11 @@ func setupStaticFile(fpath, content string) error {
if err != nil { if err != nil {
return err return err
} }
*documentRoot = path.Join(cwd, testDocumentRoot) absDocumentRoot = path.Join(cwd, testDocumentRoot)
if err := os.MkdirAll(path.Join(*documentRoot, path.Dir(fpath)), 0755); err != nil { if err := os.MkdirAll(path.Join(absDocumentRoot, path.Dir(fpath)), 0755); err != nil {
return err return err
} }
staticFile := path.Join(*documentRoot, fpath) staticFile := path.Join(absDocumentRoot, fpath)
return ioutil.WriteFile(staticFile, []byte(content), 0666) return ioutil.WriteFile(staticFile, []byte(content), 0666)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment