Commit 0d9ae330 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Merge branch 'sh-workhorse-graceful-shutdown' into 'master'

Support graceful shutdown of Workhorse connections

See merge request gitlab-org/gitlab!62701
parents 62d10e93 73f323b0
# alt_document_root = '/home/git/public/assets' # alt_document_root = '/home/git/public/assets'
# shutdown_timeout = "60s"
[redis] [redis]
URL = "unix:/home/git/gitlab/redis/redis.socket" URL = "unix:/home/git/gitlab/redis/redis.socket"
......
...@@ -16,12 +16,20 @@ import ( ...@@ -16,12 +16,20 @@ import (
"gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream" "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
) )
func TestDefaultConfig(t *testing.T) {
_, cfg, err := buildConfig("test", []string{"-config", "/dev/null"})
require.NoError(t, err, "build config")
require.Equal(t, 0*time.Second, cfg.ShutdownTimeout.Duration)
}
func TestConfigFile(t *testing.T) { func TestConfigFile(t *testing.T) {
f, err := ioutil.TempFile("", "workhorse-config-test") f, err := ioutil.TempFile("", "workhorse-config-test")
require.NoError(t, err) require.NoError(t, err)
defer os.Remove(f.Name()) defer os.Remove(f.Name())
data := ` data := `
shutdown_timeout = "60s"
[redis] [redis]
password = "redis password" password = "redis password"
[object_storage] [object_storage]
...@@ -43,6 +51,7 @@ max_scaler_procs = 123 ...@@ -43,6 +51,7 @@ max_scaler_procs = 123
require.Equal(t, "redis password", cfg.Redis.Password) require.Equal(t, "redis password", cfg.Redis.Password)
require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider) require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider)
require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs") require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs")
require.Equal(t, 60*time.Second, cfg.ShutdownTimeout.Duration)
} }
func TestConfigErrorHelp(t *testing.T) { func TestConfigErrorHelp(t *testing.T) {
......
...@@ -28,7 +28,7 @@ type TomlDuration struct { ...@@ -28,7 +28,7 @@ type TomlDuration struct {
time.Duration time.Duration
} }
func (d *TomlDuration) UnmarshalTest(text []byte) error { func (d *TomlDuration) UnmarshalText(text []byte) error {
temp, err := time.ParseDuration(string(text)) temp, err := time.ParseDuration(string(text))
d.Duration = temp d.Duration = temp
return err return err
...@@ -103,6 +103,7 @@ type Config struct { ...@@ -103,6 +103,7 @@ type Config struct {
PropagateCorrelationID bool `toml:"-"` PropagateCorrelationID bool `toml:"-"`
ImageResizerConfig ImageResizerConfig `toml:"image_resizer"` ImageResizerConfig ImageResizerConfig `toml:"image_resizer"`
AltDocumentRoot string `toml:"alt_document_root"` AltDocumentRoot string `toml:"alt_document_root"`
ShutdownTimeout TomlDuration `toml:"shutdown_timeout"`
} }
var DefaultImageResizerConfig = ImageResizerConfig{ var DefaultImageResizerConfig = ImageResizerConfig{
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
var ( var (
keyWatcher = make(map[string][]chan string) keyWatcher = make(map[string][]chan string)
keyWatcherMutex sync.Mutex keyWatcherMutex sync.Mutex
shutdown = make(chan struct{})
redisReconnectTimeout = backoff.Backoff{ redisReconnectTimeout = backoff.Backoff{
//These are the defaults //These are the defaults
Min: 100 * time.Millisecond, Min: 100 * time.Millisecond,
...@@ -112,6 +113,20 @@ func Process() { ...@@ -112,6 +113,20 @@ func Process() {
} }
} }
func Shutdown() {
log.Info("keywatcher: shutting down")
keyWatcherMutex.Lock()
defer keyWatcherMutex.Unlock()
select {
case <-shutdown:
// already closed
default:
close(shutdown)
}
}
func notifyChanWatchers(key, value string) { func notifyChanWatchers(key, value string) {
keyWatcherMutex.Lock() keyWatcherMutex.Lock()
defer keyWatcherMutex.Unlock() defer keyWatcherMutex.Unlock()
...@@ -182,6 +197,9 @@ func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) ...@@ -182,6 +197,9 @@ func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error)
} }
select { select {
case <-shutdown:
log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown")
return WatchKeyStatusNoChange, nil
case currentValue := <-kw.Chan: case currentValue := <-kw.Chan:
if currentValue == "" { if currentValue == "" {
return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
......
...@@ -160,3 +160,58 @@ func TestWatchKeyMassivelyParallel(t *testing.T) { ...@@ -160,3 +160,58 @@ func TestWatchKeyMassivelyParallel(t *testing.T) {
processMessages(runTimes, "somethingelse") processMessages(runTimes, "somethingelse")
wg.Wait() wg.Wait()
} }
func TestShutdown(t *testing.T) {
conn, td := setupMockPool()
defer td()
defer func() { shutdown = make(chan struct{}) }()
conn.Command("GET", runnerKey).Expect("something")
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
val, err := WatchKey(runnerKey, "something", 10*time.Second)
require.NoError(t, err, "Expected no error")
require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change")
wg.Done()
}()
go func() {
for countWatchers(runnerKey) == 0 {
time.Sleep(time.Millisecond)
}
require.Equal(t, 1, countWatchers(runnerKey))
Shutdown()
wg.Done()
}()
wg.Wait()
for countWatchers(runnerKey) == 1 {
time.Sleep(time.Millisecond)
}
require.Equal(t, 0, countWatchers(runnerKey))
// Adding a key after the shutdown should result in an immediate response
var val WatchKeyStatus
var err error
done := make(chan struct{})
go func() {
val, err = WatchKey(runnerKey, "something", 10*time.Second)
close(done)
}()
select {
case <-done:
require.NoError(t, err, "Expected no error")
require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change")
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for WatchKey")
}
}
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -8,6 +9,7 @@ import ( ...@@ -8,6 +9,7 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal"
"syscall" "syscall"
"time" "time"
...@@ -144,6 +146,7 @@ func buildConfig(arg0 string, args []string) (*bootConfig, *config.Config, error ...@@ -144,6 +146,7 @@ func buildConfig(arg0 string, args []string) (*bootConfig, *config.Config, error
cfg.ObjectStorageCredentials = cfgFromFile.ObjectStorageCredentials cfg.ObjectStorageCredentials = cfgFromFile.ObjectStorageCredentials
cfg.ImageResizerConfig = cfgFromFile.ImageResizerConfig cfg.ImageResizerConfig = cfgFromFile.ImageResizerConfig
cfg.AltDocumentRoot = cfgFromFile.AltDocumentRoot cfg.AltDocumentRoot = cfgFromFile.AltDocumentRoot
cfg.ShutdownTimeout = cfgFromFile.ShutdownTimeout
return boot, cfg, nil return boot, cfg, nil
} }
...@@ -225,7 +228,22 @@ func run(boot bootConfig, cfg config.Config) error { ...@@ -225,7 +228,22 @@ func run(boot bootConfig, cfg config.Config) error {
up := wrapRaven(upstream.NewUpstream(cfg, accessLogger)) up := wrapRaven(upstream.NewUpstream(cfg, accessLogger))
go func() { finalErrors <- http.Serve(listener, up) }() done := make(chan os.Signal, 1)
signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
return <-finalErrors server := http.Server{Handler: up}
go func() { finalErrors <- server.Serve(listener) }()
select {
case err := <-finalErrors:
return err
case sig := <-done:
log.WithFields(log.Fields{"shutdown_timeout_s": cfg.ShutdownTimeout.Duration.Seconds(), "signal": sig.String()}).Infof("shutdown initiated")
ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background
defer cancel()
redis.Shutdown()
return server.Shutdown(ctx)
}
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment