diff --git a/vault/ha.go b/vault/ha.go index 7cfb5139a72d..07e0b3da2afb 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strings" + "sync" "sync/atomic" "time" @@ -638,35 +639,48 @@ func (c *Core) waitForLeadership(newLeaderCh chan func(), manualStepDownCh, stop } } -// grabLockOrStop returns true if we failed to get the lock before stopCh -// was closed. Returns false if the lock was obtained, in which case it's -// the caller's responsibility to unlock it. +// grabLockOrStop returns stopped=false if the lock is acquired. Returns +// stopped=true if the lock is not acquired, because stopCh was closed. If the +// lock was acquired (stopped=false) then it's up to the caller to unlock. func grabLockOrStop(lockFunc, unlockFunc func(), stopCh chan struct{}) (stopped bool) { - // Grab the lock as we need it for cluster setup, which needs to happen - // before advertising; - lockGrabbedCh := make(chan struct{}) + // lock protects these variables which are shared by parent and child. + var lock sync.Mutex + parentWaiting := true + locked := false + + // doneCh is closed when the child goroutine is done. + doneCh := make(chan struct{}) go func() { - // Grab the lock + defer close(doneCh) lockFunc() - // If stopCh has been closed, which only happens while the - // stateLock is held, we have actually terminated, so we just - // instantly give up the lock, otherwise we notify that it's ready - // for consumption - select { - case <-stopCh: + + // The parent goroutine may or may not be waiting. + lock.Lock() + defer lock.Unlock() + if !parentWaiting { unlockFunc() - default: - close(lockGrabbedCh) + } else { + locked = true } }() + stop := false select { case <-stopCh: - return true - case <-lockGrabbedCh: - // We now have the lock and can use it + stop = true + case <-doneCh: } + // The child goroutine may not have acquired the lock yet. + lock.Lock() + defer lock.Unlock() + parentWaiting = false + if stop { + if locked { + unlockFunc() + } + return true + } return false } diff --git a/vault/ha_test.go b/vault/ha_test.go new file mode 100644 index 000000000000..9e753b2c56b5 --- /dev/null +++ b/vault/ha_test.go @@ -0,0 +1,84 @@ +package vault + +import ( + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestGrabLockOrStopped is a non-deterministic test to detect deadlocks in the +// grabLockOrStopped function. This test starts a bunch of workers which +// continually lock/unlock and rlock/runlock the same RWMutex. Each worker also +// starts a goroutine which closes the stop channel 1/2 the time, which races +// with acquisition of the lock. +func TestGrabLockOrStop(t *testing.T) { + // Stop the test early if we deadlock. + const ( + workers = 100 + testDuration = time.Second + testTimeout = 10*testDuration + ) + done := make(chan struct{}) + defer close(done) + var lockCount int64 + go func() { + select{ + case <-done: + case <-time.After(testTimeout): + panic(fmt.Sprintf("deadlock after %d lock count", + atomic.LoadInt64(&lockCount))) + } + }() + + // lock is locked/unlocked and rlocked/runlocked concurrently. + var lock sync.RWMutex + start := time.Now() + + // workerWg is used to wait until all workers exit. + var workerWg sync.WaitGroup + workerWg.Add(workers) + + // Start a bunch of worker goroutines. + for g := 0; g < workers; g++ { + g := g + go func() { + defer workerWg.Done() + for time.Now().Sub(start) < testDuration { + stop := make(chan struct{}) + + // closerWg waits until the closer goroutine exits before we do + // another iteration. This makes sure goroutines don't pile up. + var closerWg sync.WaitGroup + closerWg.Add(1) + go func() { + defer closerWg.Done() + // Close the stop channel half the time. + if rand.Int() % 2 == 0 { + close(stop) + } + }() + + // Half the goroutines lock/unlock and the other half rlock/runlock. + if g % 2 == 0 { + if !grabLockOrStop(lock.Lock, lock.Unlock, stop) { + lock.Unlock() + } + } else { + if !grabLockOrStop(lock.RLock, lock.RUnlock, stop) { + lock.RUnlock() + } + } + + closerWg.Wait() + + // This lets us know how many lock/unlock and rlock/runlock have + // happened if there's a deadlock. + atomic.AddInt64(&lockCount, 1) + } + }() + } + workerWg.Wait() +} \ No newline at end of file