Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a role cache to avoid separate locking paths #6926

Merged
merged 5 commits into from
Jun 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package awsauth
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -76,6 +77,9 @@ type backend struct {
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string

// roleCache caches role entries to avoid locking headaches
roleCache *cache.Cache

resolveArnToUniqueIDFunc func(context.Context, logical.Storage, string) (string, error)
}

Expand All @@ -89,6 +93,7 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour),
tidyBlacklistCASGuard: new(uint32),
tidyWhitelistCASGuard: new(uint32),
roleCache: cache.New(cache.NoExpiration, cache.NoExpiration),
}

b.resolveArnToUniqueIDFunc = b.resolveArnToRealUniqueId
Expand Down Expand Up @@ -201,13 +206,16 @@ func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error
}

func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/client":
switch {
case key == "config/client":
b.configMutex.Lock()
defer b.configMutex.Unlock()
b.flushCachedEC2Clients()
b.flushCachedIAMClients()
b.defaultAWSAccountID = ""
case strings.HasPrefix(key, "role"):
// TODO: We could make this better
b.roleCache.Flush()
jefferai marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
4 changes: 2 additions & 2 deletions builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}

// read the created role entry
roleEntry, err := b.lockedAWSRole(context.Background(), storage, "abcd-123")
roleEntry, err := b.role(context.Background(), storage, "abcd-123")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}

// get the entry of the newly created role entry
roleEntry2, err := b.lockedAWSRole(context.Background(), storage, "ami-6789")
roleEntry2, err := b.role(context.Background(), storage, "ami-6789")
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}

// Get the entry for the role used by the instance
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -951,7 +951,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if roleName == "" {
return nil, fmt.Errorf("error retrieving role_name during renewal")
}
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1079,7 +1079,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}

// Ensure that role entry is not deleted
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, storedIdentity.Role)
roleEntry, err := b.role(ctx, req.Storage, storedIdentity.Role)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1205,7 +1205,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
roleName = entity.FriendlyName
}

roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
roleEntry, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestBackend_pathLogin_IAMHeaders(t *testing.T) {
AuthType: iamAuthType,
}

if err := b.nonLockedSetAWSRole(context.Background(), storage, testValidRoleName, roleEntry); err != nil {
if err := b.setRole(context.Background(), storage, testValidRoleName, roleEntry); err != nil {
t.Fatalf("failed to set entry: %s", err)
}

Expand Down
172 changes: 78 additions & 94 deletions builtin/credential/aws/path_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package awsauth

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/copystructure"
)

var (
Expand Down Expand Up @@ -218,82 +220,83 @@ func pathListRoles(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
entry, err := b.role(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return false, err
}
return entry != nil, nil
}

// lockedAWSRole returns the properties set on the given role. This method
// acquires the read lock before reading the role from the storage.
func (b *backend) lockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
// role fetches the role entry from cache, or loads from disk if necessary
func (b *backend) role(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}

b.roleMutex.RLock()
roleEntry, err := b.nonLockedAWSRole(ctx, s, roleName)
// we manually unlock rather than defer the unlock because we might need to grab
// a read/write lock in the upgrade path
b.roleMutex.RUnlock()
roleEntryRaw, found := b.roleCache.Get(roleName)
if found && roleEntryRaw != nil {
roleEntry, ok := roleEntryRaw.(*awsRoleEntry)
if !ok {
return nil, errors.New("could not convert role entry internally")
}
if roleEntry == nil {
return nil, errors.New("converted role entry is nil")
}
}

// Not found, or was nil
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return b.roleInternal(ctx, s, roleName)
}

// roleInternal does not perform locking, and rechecks the cache, going to disk if necessar
func (b *backend) roleInternal(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
// Check cache again now that we have the lock
roleEntryRaw, found := b.roleCache.Get(roleName)
if found && roleEntryRaw != nil {
roleEntry, ok := roleEntryRaw.(*awsRoleEntry)
if !ok {
return nil, errors.New("could not convert role entry internally")
}
if roleEntry == nil {
return nil, errors.New("converted role entry is nil")
}
}

// Fetch from storage
entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
if roleEntry == nil {
if entry == nil {
return nil, nil
}
needUpgrade, err := b.upgradeRoleEntry(ctx, s, roleEntry)

result := new(awsRoleEntry)
if err := entry.DecodeJSON(result); err != nil {
return nil, err
}

needUpgrade, err := b.upgradeRole(ctx, s, result)
if err != nil {
return nil, errwrap.Wrapf("error upgrading roleEntry: {{err}}", err)
}
if needUpgrade && (b.System().LocalMount() || !b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary|consts.ReplicationPerformanceStandby)) {
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
// Now that we have a R/W lock, we need to re-read the role entry in case it was
// written to between releasing the read lock and acquiring the write lock
roleEntry, err = b.nonLockedAWSRole(ctx, s, roleName)
if err != nil {
return nil, err
}
// somebody deleted the role, so no use in putting it back
if roleEntry == nil {
return nil, nil
if err = b.setRole(ctx, s, roleName, result); err != nil {
return nil, errwrap.Wrapf("error saving upgraded roleEntry: {{err}}", err)
}
// now re-check to see if we need to upgrade
if needUpgrade, err = b.upgradeRoleEntry(ctx, s, roleEntry); err != nil {
return nil, errwrap.Wrapf("error upgrading roleEntry: {{err}}", err)
}
if needUpgrade {
if err = b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry); err != nil {
return nil, errwrap.Wrapf("error saving upgraded roleEntry: {{err}}", err)
}
}
}
return roleEntry, nil
}

// lockedSetAWSRole creates or updates a role in the storage. This method
// acquires the write lock before creating or updating the role at the storage.
func (b *backend) lockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}

if roleEntry == nil {
return fmt.Errorf("nil role entry")
}
b.roleCache.SetDefault(roleName, result)

b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry)
return result, nil
}

// nonLockedSetAWSRole creates or updates a role in the storage. This method
// does not acquire the write lock before reading the role from the storage. If
// locking is desired, use lockedSetAWSRole instead.
func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string,
// setRole creates or updates a role in the storage. The caller must hold
// the write lock.
func (b *backend) setRole(ctx context.Context, s logical.Storage, roleName string,
roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
Expand All @@ -312,12 +315,14 @@ func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, ro
return err
}

b.roleCache.SetDefault(roleName, roleEntry)

return nil
}

// If needed, updates the role entry and returns a bool indicating if it was updated
// (and thus needs to be persisted)
func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
func (b *backend) upgradeRole(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
if roleEntry == nil {
return false, fmt.Errorf("received nil roleEntry")
}
Expand Down Expand Up @@ -421,33 +426,6 @@ func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleE
return upgraded, nil
}

// nonLockedAWSRole returns the properties set on the given role. This method
// does not acquire the read lock before reading the role from the storage. If
// locking is desired, use lockedAWSRole instead.
// This method also does NOT check to see if a role upgrade is required. It is
// the responsibility of the caller to check if a role upgrade is required and,
// if so, to upgrade the role
func (b *backend) nonLockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}

entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}

var result awsRoleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}

return &result, nil
}

// pathRoleDelete is used to delete the information registered for a given AMI ID.
func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleName := data.Get("role").(string)
Expand All @@ -458,24 +436,29 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

return nil, req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
err := req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, errwrap.Wrapf("error deleting role: {{err}}", err)
}

b.roleCache.Delete(roleName)

return nil, nil
}

// pathRoleList is used to list all the AMI IDs registered with Vault.
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()

roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}

return logical.ListResponse(roles), nil
}

// pathRoleRead is used to view the information registered for a given AMI ID.
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
roleEntry, err := b.role(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return nil, err
}
Expand All @@ -498,7 +481,10 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
b.roleMutex.Lock()
defer b.roleMutex.Unlock()

roleEntry, err := b.nonLockedAWSRole(ctx, req.Storage, roleName)
// We use the internal one here to ensure that we have fresh data and
// nobody else is concurrently modifying. This will also call the upgrade
// path on existing role entries.
roleEntry, err := b.roleInternal(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
Expand All @@ -512,16 +498,14 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
Version: currentRoleStorageVersion,
}
} else {
needUpdate, err := b.upgradeRoleEntry(ctx, req.Storage, roleEntry)
// We want to always use a copy so we aren't modifying items in the
// version in the cache while other users may be looking it up (or if
// we fail somewhere)
cp, err := copystructure.Copy(roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to update roleEntry: %v", err)), nil
}
if needUpdate {
err = b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to save upgraded roleEntry: %v", err)), nil
}
return nil, err
}
roleEntry = cp.(*awsRoleEntry)
}

// Fetch and set the bound parameters. There can't be default values
Expand Down Expand Up @@ -808,7 +792,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
}

if err := b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry); err != nil {
if err := b.setRole(ctx, req.Storage, roleName, roleEntry); err != nil {
return nil, err
}

Expand Down
Loading