From 53f6f604c305d63e3a899eb9bea1b5def0c63a28 Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Wed, 1 May 2019 11:35:08 -0700 Subject: [PATCH 1/2] Fix Okta auth to allow group names containing slashes This PR also adds CollectKeysPrefix which allows a more memory efficient key scan for those cases where the result is immediately filtered by prefix. --- builtin/credential/ldap/path_groups.go | 6 +- builtin/credential/ldap/path_users.go | 12 +-- builtin/credential/okta/path_groups.go | 24 ++++- builtin/credential/okta/path_groups_test.go | 108 ++++++++++++++++++++ sdk/logical/storage.go | 15 ++- sdk/logical/storage_test.go | 86 ++++++++++++++++ 6 files changed, 232 insertions(+), 19 deletions(-) create mode 100644 builtin/credential/okta/path_groups_test.go create mode 100644 sdk/logical/storage_test.go diff --git a/builtin/credential/ldap/path_groups.go b/builtin/credential/ldap/path_groups.go index 9840f4320bda..e11810a29912 100644 --- a/builtin/credential/ldap/path_groups.go +++ b/builtin/credential/ldap/path_groups.go @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path { return &framework.Path{ Pattern: `groups/(?P.+)`, Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ + "name": { Type: framework.TypeString, Description: "Name of the LDAP group.", }, - "policies": &framework.FieldSchema{ + "policies": { Type: framework.TypeCommaStringSlice, Description: "Comma-separated list of policies associated to the group.", }, @@ -132,7 +132,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f } func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - keys, err := logical.CollectKeys(ctx, req.Storage) + keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "group/") if err != nil { return nil, err } diff --git a/builtin/credential/ldap/path_users.go b/builtin/credential/ldap/path_users.go index 20d6a95fd3cf..922191c96a2f 100644 --- a/builtin/credential/ldap/path_users.go +++ b/builtin/credential/ldap/path_users.go @@ -148,18 +148,14 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr } func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - keys, err := logical.CollectKeys(ctx, req.Storage) + keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "user/") if err != nil { return nil, err } - retKeys := make([]string, 0) - for _, key := range keys { - if strings.HasPrefix(key, "user/") && !strings.HasPrefix(key, "/") { - retKeys = append(retKeys, strings.TrimPrefix(key, "user/")) - } + for i := range keys { + keys[i] = strings.TrimPrefix(keys[i], "user/") } - return logical.ListResponse(retKeys), nil - + return logical.ListResponse(keys), nil } type UserEntry struct { diff --git a/builtin/credential/okta/path_groups.go b/builtin/credential/okta/path_groups.go index 05fcaee7d123..e0a849c0292d 100644 --- a/builtin/credential/okta/path_groups.go +++ b/builtin/credential/okta/path_groups.go @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path { return &framework.Path{ Pattern: `groups/(?P.+)`, Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ + "name": { Type: framework.TypeString, Description: "Name of the Okta group.", }, - "policies": &framework.FieldSchema{ + "policies": { Type: framework.TypeCommaStringSlice, Description: "Comma-separated list of policies associated to the group.", }, @@ -57,10 +57,12 @@ func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*Grou return nil, "", err } if entry == nil { - entries, err := s.List(ctx, "group/") + entries, err := groupList(ctx, s) if err != nil { return nil, "", err + } + for _, groupName := range entries { if strings.EqualFold(groupName, n) { entry, err = s.Get(ctx, "group/"+groupName) @@ -157,13 +159,27 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f } func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - groups, err := req.Storage.List(ctx, "group/") + groups, err := groupList(ctx, req.Storage) if err != nil { return nil, err } + return logical.ListResponse(groups), nil } +func groupList(ctx context.Context, s logical.Storage) ([]string, error) { + keys, err := logical.CollectKeysPrefix(ctx, s, "group/") + if err != nil { + return nil, err + } + + for i := range keys { + keys[i] = strings.TrimPrefix(keys[i], "group/") + } + + return keys, nil +} + type GroupEntry struct { Policies []string } diff --git a/builtin/credential/okta/path_groups_test.go b/builtin/credential/okta/path_groups_test.go new file mode 100644 index 000000000000..84253f379fd8 --- /dev/null +++ b/builtin/credential/okta/path_groups_test.go @@ -0,0 +1,108 @@ +package okta + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/go-test/deep" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" +) + +func TestGroupsList(t *testing.T) { + b, storage := getBackend(t) + + groups := []string{ + "%20\\", + "foo", + "zfoo", + "🙂", + "foo/nested", + "foo/even/more/nested", + } + + for _, group := range groups { + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "groups/" + group, + Storage: storage, + Data: map[string]interface{}{ + "policies": []string{group + "_a", group + "_b"}, + }, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + } + + for _, group := range groups { + for _, upper := range []bool{false, true} { + groupPath := group + if upper { + groupPath = strings.ToUpper(group) + } + req := &logical.Request{ + Operation: logical.ReadOperation, + Path: "groups/" + groupPath, + Storage: storage, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if resp == nil { + t.Fatal("unexpected nil response") + } + + expected := []string{group + "_a", group + "_b"} + + if diff := deep.Equal(resp.Data["policies"].([]string), expected); diff != nil { + t.Fatal(diff) + } + } + } + + req := &logical.Request{ + Operation: logical.ListOperation, + Path: "groups", + Storage: storage, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if diff := deep.Equal(resp.Data["keys"].([]string), groups); diff != nil { + t.Fatal(diff) + } +} + +func getBackend(t *testing.T) (logical.Backend, logical.Storage) { + defaultLeaseTTLVal := time.Hour * 12 + maxLeaseTTLVal := time.Hour * 24 + + config := &logical.BackendConfig{ + Logger: logging.NewVaultLogger(log.Trace), + + System: &logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLVal, + MaxLeaseTTLVal: maxLeaseTTLVal, + }, + StorageView: &logical.InmemStorage{}, + } + b, err := Factory(context.Background(), config) + if err != nil { + t.Fatalf("unable to create backend: %v", err) + } + + return b, config.StorageView +} diff --git a/sdk/logical/storage.go b/sdk/logical/storage.go index 15f480978c1c..0db596a84479 100644 --- a/sdk/logical/storage.go +++ b/sdk/logical/storage.go @@ -86,17 +86,24 @@ func ScanView(ctx context.Context, view ClearableView, cb func(path string)) err // CollectKeys is used to collect all the keys in a view func CollectKeys(ctx context.Context, view ClearableView) ([]string, error) { - // Accumulate the keys - var existing []string + return CollectKeysPrefix(ctx, view, "") +} + +// CollectKeysPrefix is used to collect all the keys in a view with a given prefix string +func CollectKeysPrefix(ctx context.Context, view ClearableView, prefix string) ([]string, error) { + var keys []string + cb := func(path string) { - existing = append(existing, path) + if strings.HasPrefix(path, prefix) { + keys = append(keys, path) + } } // Scan for all the keys if err := ScanView(ctx, view, cb); err != nil { return nil, err } - return existing, nil + return keys, nil } // ClearView is used to delete all the keys in a view diff --git a/sdk/logical/storage_test.go b/sdk/logical/storage_test.go new file mode 100644 index 000000000000..53a26eb20cac --- /dev/null +++ b/sdk/logical/storage_test.go @@ -0,0 +1,86 @@ +package logical + +import ( + "context" + "testing" + + "github.com/go-test/deep" +) + +var keyList = []string{ + "a", + "b", + "d", + "foo", + "foo42", + "foo/a/b/c", + "c/d/e/f/g", +} + +func TestScanView(t *testing.T) { + s := prepKeyStorage(t) + + keys := make([]string, 0) + err := ScanView(context.Background(), s, func(path string) { + keys = append(keys, path) + }) + + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(keys, keyList); diff != nil { + t.Fatal(diff) + } +} + +func TestCollectKeys(t *testing.T) { + s := prepKeyStorage(t) + + keys, err := CollectKeys(context.Background(), s) + + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(keys, keyList); diff != nil { + t.Fatal(diff) + } +} + +func TestCollectKeysPrefix(t *testing.T) { + s := prepKeyStorage(t) + + keys, err := CollectKeysPrefix(context.Background(), s, "foo") + + if err != nil { + t.Fatal(err) + } + + exp := []string{ + "foo", + "foo42", + "foo/a/b/c", + } + + if diff := deep.Equal(keys, exp); diff != nil { + t.Fatal(diff) + } +} + +func prepKeyStorage(t *testing.T) Storage { + t.Helper() + s := &InmemStorage{} + + for _, key := range keyList { + if err := s.Put(context.Background(), &StorageEntry{ + Key: key, + Value: nil, + SealWrap: false, + }); err != nil { + t.Fatal(err) + } + } + + return s +} From 8d42242ff6ed464ea727a39adbb9cb1b6056cd15 Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Wed, 1 May 2019 13:45:25 -0700 Subject: [PATCH 2/2] Rename to CollectKeysWithPrefix and remove another retkeys instance --- builtin/credential/ldap/path_groups.go | 11 ++++------- builtin/credential/ldap/path_users.go | 2 +- builtin/credential/okta/path_groups.go | 2 +- sdk/logical/storage.go | 4 ++-- sdk/logical/storage_test.go | 2 +- 5 files changed, 9 insertions(+), 12 deletions(-) diff --git a/builtin/credential/ldap/path_groups.go b/builtin/credential/ldap/path_groups.go index e11810a29912..c8a33d9d5748 100644 --- a/builtin/credential/ldap/path_groups.go +++ b/builtin/credential/ldap/path_groups.go @@ -132,17 +132,14 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f } func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "group/") + keys, err := logical.CollectKeysWithPrefix(ctx, req.Storage, "group/") if err != nil { return nil, err } - retKeys := make([]string, 0) - for _, key := range keys { - if strings.HasPrefix(key, "group/") && !strings.HasPrefix(key, "/") { - retKeys = append(retKeys, strings.TrimPrefix(key, "group/")) - } + for i := range keys { + keys[i] = strings.TrimPrefix(keys[i], "group/") } - return logical.ListResponse(retKeys), nil + return logical.ListResponse(keys), nil } type GroupEntry struct { diff --git a/builtin/credential/ldap/path_users.go b/builtin/credential/ldap/path_users.go index 922191c96a2f..2eb566db325a 100644 --- a/builtin/credential/ldap/path_users.go +++ b/builtin/credential/ldap/path_users.go @@ -148,7 +148,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr } func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "user/") + keys, err := logical.CollectKeysWithPrefix(ctx, req.Storage, "user/") if err != nil { return nil, err } diff --git a/builtin/credential/okta/path_groups.go b/builtin/credential/okta/path_groups.go index e0a849c0292d..83742035035f 100644 --- a/builtin/credential/okta/path_groups.go +++ b/builtin/credential/okta/path_groups.go @@ -168,7 +168,7 @@ func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *fr } func groupList(ctx context.Context, s logical.Storage) ([]string, error) { - keys, err := logical.CollectKeysPrefix(ctx, s, "group/") + keys, err := logical.CollectKeysWithPrefix(ctx, s, "group/") if err != nil { return nil, err } diff --git a/sdk/logical/storage.go b/sdk/logical/storage.go index 0db596a84479..e95bef7537f8 100644 --- a/sdk/logical/storage.go +++ b/sdk/logical/storage.go @@ -86,11 +86,11 @@ func ScanView(ctx context.Context, view ClearableView, cb func(path string)) err // CollectKeys is used to collect all the keys in a view func CollectKeys(ctx context.Context, view ClearableView) ([]string, error) { - return CollectKeysPrefix(ctx, view, "") + return CollectKeysWithPrefix(ctx, view, "") } // CollectKeysPrefix is used to collect all the keys in a view with a given prefix string -func CollectKeysPrefix(ctx context.Context, view ClearableView, prefix string) ([]string, error) { +func CollectKeysWithPrefix(ctx context.Context, view ClearableView, prefix string) ([]string, error) { var keys []string cb := func(path string) { diff --git a/sdk/logical/storage_test.go b/sdk/logical/storage_test.go index 53a26eb20cac..aea4e8095d46 100644 --- a/sdk/logical/storage_test.go +++ b/sdk/logical/storage_test.go @@ -51,7 +51,7 @@ func TestCollectKeys(t *testing.T) { func TestCollectKeysPrefix(t *testing.T) { s := prepKeyStorage(t) - keys, err := CollectKeysPrefix(context.Background(), s, "foo") + keys, err := CollectKeysWithPrefix(context.Background(), s, "foo") if err != nil { t.Fatal(err)