Skip to content

Commit

Permalink
Use a role cache to avoid separate locking paths (#6926)
Browse files Browse the repository at this point in the history
* Use a role cache to avoid separate locking paths

Due to the various locked/nonlocked paths we had a case where we weren't
always checking for secondary status before trying to upgrade. This
broadly simplifies things by using a cache to store the current role
values (avoiding a lot of storage hits) and updating the cache on any
write, delete, or invalidation.
  • Loading branch information
jefferai authored Jun 20, 2019
1 parent 41973bb commit 4ff9001
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 113 deletions.
12 changes: 10 additions & 2 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package awsauth
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -76,6 +77,9 @@ type backend struct {
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string

// roleCache caches role entries to avoid locking headaches
roleCache *cache.Cache

resolveArnToUniqueIDFunc func(context.Context, logical.Storage, string) (string, error)
}

Expand All @@ -89,6 +93,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour),
tidyBlacklistCASGuard: new(uint32),
tidyWhitelistCASGuard: new(uint32),
roleCache: cache.New(cache.NoExpiration, cache.NoExpiration),
}

b.resolveArnToUniqueIDFunc = b.resolveArnToRealUniqueId
Expand Down Expand Up @@ -201,13 +206,16 @@ func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error
}

func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/client":
switch {
case key == "config/client":
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
case strings.HasPrefix(key, "role"):
// TODO: We could make this better
b.roleCache.Flush()
}
}

Expand Down
4 changes: 2 additions & 2 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}

// read the created role entry
roleEntry, err := b.lockedAWSRole(context.Background(), storage, "abcd-123")
roleEntry, err := b.role(context.Background(), storage, "abcd-123")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}

// get the entry of the newly created role entry
roleEntry2, err := b.lockedAWSRole(context.Background(), storage, "ami-6789")
roleEntry2, err := b.role(context.Background(), storage, "ami-6789")
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}

// Get the entry for the role used by the instance
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -951,7 +951,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if roleName == "" {
return nil, fmt.Errorf("error retrieving role_name during renewal")
}
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1079,7 +1079,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}

// Ensure that role entry is not deleted
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, storedIdentity.Role)
roleEntry, err := b.role(ctx, req.Storage, storedIdentity.Role)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1205,7 +1205,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
roleName = entity.FriendlyName
}

roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestBackend_pathLogin_IAMHeaders(t *testing.T) {
AuthType: iamAuthType,
}

if err := b.nonLockedSetAWSRole(context.Background(), storage, testValidRoleName, roleEntry); err != nil {
if err := b.setRole(context.Background(), storage, testValidRoleName, roleEntry); err != nil {
t.Fatalf("failed to set entry: %s", err)
}

Expand Down
172 changes: 78 additions & 94 deletions builtin/credential/aws/path_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package awsauth

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/copystructure"
)

var (
Expand Down Expand Up @@ -218,82 +220,83 @@ func pathListRoles(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
entry, err := b.role(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return false, err
}
return entry != nil, nil
}

// lockedAWSRole returns the properties set on the given role. This method
// acquires the read lock before reading the role from the storage.
func (b *backend) lockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
// role fetches the role entry from cache, or loads from disk if necessary
func (b *backend) role(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}

b.roleMutex.RLock()
roleEntry, err := b.nonLockedAWSRole(ctx, s, roleName)
// we manually unlock rather than defer the unlock because we might need to grab
// a read/write lock in the upgrade path
b.roleMutex.RUnlock()
roleEntryRaw, found := b.roleCache.Get(roleName)
if found && roleEntryRaw != nil {
roleEntry, ok := roleEntryRaw.(*awsRoleEntry)
if !ok {
return nil, errors.New("could not convert role entry internally")
}
if roleEntry == nil {
return nil, errors.New("converted role entry is nil")
}
}

// Not found, or was nil
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return b.roleInternal(ctx, s, roleName)
}

// roleInternal does not perform locking, and rechecks the cache, going to disk if necessar
func (b *backend) roleInternal(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
// Check cache again now that we have the lock
roleEntryRaw, found := b.roleCache.Get(roleName)
if found && roleEntryRaw != nil {
roleEntry, ok := roleEntryRaw.(*awsRoleEntry)
if !ok {
return nil, errors.New("could not convert role entry internally")
}
if roleEntry == nil {
return nil, errors.New("converted role entry is nil")
}
}

// Fetch from storage
entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
if roleEntry == nil {
if entry == nil {
return nil, nil
}
needUpgrade, err := b.upgradeRoleEntry(ctx, s, roleEntry)

result := new(awsRoleEntry)
if err := entry.DecodeJSON(result); err != nil {
return nil, err
}

needUpgrade, err := b.upgradeRole(ctx, s, result)
if err != nil {
return nil, errwrap.Wrapf("error upgrading roleEntry: {{err}}", err)
}
if needUpgrade && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary|consts.ReplicationPerformanceStandby)) {
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
// Now that we have a R/W lock, we need to re-read the role entry in case it was
// written to between releasing the read lock and acquiring the write lock
roleEntry, err = b.nonLockedAWSRole(ctx, s, roleName)
if err != nil {
return nil, err
}
// somebody deleted the role, so no use in putting it back
if roleEntry == nil {
return nil, nil
if err = b.setRole(ctx, s, roleName, result); err != nil {
return nil, errwrap.Wrapf("error saving upgraded roleEntry: {{err}}", err)
}
// now re-check to see if we need to upgrade
if needUpgrade, err = b.upgradeRoleEntry(ctx, s, roleEntry); err != nil {
return nil, errwrap.Wrapf("error upgrading roleEntry: {{err}}", err)
}
if needUpgrade {
if err = b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry); err != nil {
return nil, errwrap.Wrapf("error saving upgraded roleEntry: {{err}}", err)
}
}
}
return roleEntry, nil
}

// lockedSetAWSRole creates or updates a role in the storage. This method
// acquires the write lock before creating or updating the role at the storage.
func (b *backend) lockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}

if roleEntry == nil {
return fmt.Errorf("nil role entry")
}
b.roleCache.SetDefault(roleName, result)

b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry)
return result, nil
}

// nonLockedSetAWSRole creates or updates a role in the storage. This method
// does not acquire the write lock before reading the role from the storage. If
// locking is desired, use lockedSetAWSRole instead.
func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string,
// setRole creates or updates a role in the storage. The caller must hold
// the write lock.
func (b *backend) setRole(ctx context.Context, s logical.Storage, roleName string,
roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
Expand All @@ -312,12 +315,14 @@ func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, ro
return err
}

b.roleCache.SetDefault(roleName, roleEntry)

return nil
}

// If needed, updates the role entry and returns a bool indicating if it was updated
// (and thus needs to be persisted)
func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
func (b *backend) upgradeRole(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
if roleEntry == nil {
return false, fmt.Errorf("received nil roleEntry")
}
Expand Down Expand Up @@ -421,33 +426,6 @@ func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleE
return upgraded, nil
}

// nonLockedAWSRole returns the properties set on the given role. This method
// does not acquire the read lock before reading the role from the storage. If
// locking is desired, use lockedAWSRole instead.
// This method also does NOT check to see if a role upgrade is required. It is
// the responsibility of the caller to check if a role upgrade is required and,
// if so, to upgrade the role
func (b *backend) nonLockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}

entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}

var result awsRoleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}

return &result, nil
}

// pathRoleDelete is used to delete the information registered for a given AMI ID.
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleName := data.Get("role").(string)
Expand All @@ -458,24 +436,29 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return nil, req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
err := req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, errwrap.Wrapf("error deleting role: {{err}}", err)
}

b.roleCache.Delete(roleName)

return nil, nil
}

// pathRoleList is used to list all the AMI IDs registered with Vault.
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()

roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}

return logical.ListResponse(roles), nil
}

// pathRoleRead is used to view the information registered for a given AMI ID.
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
roleEntry, err := b.role(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return nil, err
}
Expand All @@ -498,7 +481,10 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

roleEntry, err := b.nonLockedAWSRole(ctx, req.Storage, roleName)
// We use the internal one here to ensure that we have fresh data and
// nobody else is concurrently modifying. This will also call the upgrade
// path on existing role entries.
roleEntry, err := b.roleInternal(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand All @@ -512,16 +498,14 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
Version: currentRoleStorageVersion,
}
} else {
needUpdate, err := b.upgradeRoleEntry(ctx, req.Storage, roleEntry)
// We want to always use a copy so we aren't modifying items in the
// version in the cache while other users may be looking it up (or if
// we fail somewhere)
cp, err := copystructure.Copy(roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to update roleEntry: %v", err)), nil
}
if needUpdate {
err = b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to save upgraded roleEntry: %v", err)), nil
}
return nil, err
}
roleEntry = cp.(*awsRoleEntry)
}

// Fetch and set the bound parameters. There can't be default values
Expand Down Expand Up @@ -808,7 +792,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
}

if err := b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry); err != nil {
if err := b.setRole(ctx, req.Storage, roleName, roleEntry); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 4ff9001

Please sign in to comment.