Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Okta auth to allow group names containing slashes #6665

Merged
merged 2 commits into from
May 1, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions builtin/credential/ldap/path_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the LDAP group.",
},

"policies": &framework.FieldSchema{
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies associated to the group.",
},
Expand Down Expand Up @@ -132,7 +132,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}

func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keys, err := logical.CollectKeys(ctx, req.Storage)
keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "group/")
if err != nil {
return nil, err
}
Expand Down
12 changes: 4 additions & 8 deletions builtin/credential/ldap/path_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,14 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}

func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keys, err := logical.CollectKeys(ctx, req.Storage)
keys, err := logical.CollectKeysPrefix(ctx, req.Storage, "user/")
if err != nil {
return nil, err
}
retKeys := make([]string, 0)
for _, key := range keys {
if strings.HasPrefix(key, "user/") && !strings.HasPrefix(key, "/") {
retKeys = append(retKeys, strings.TrimPrefix(key, "user/"))
}
for i := range keys {
kalafut marked this conversation as resolved.
Show resolved Hide resolved
keys[i] = strings.TrimPrefix(keys[i], "user/")
}
return logical.ListResponse(retKeys), nil

return logical.ListResponse(keys), nil
}

type UserEntry struct {
Expand Down
24 changes: 20 additions & 4 deletions builtin/credential/okta/path_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ func pathGroups(b *backend) *framework.Path {
return &framework.Path{
Pattern: `groups/(?P<name>.+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the Okta group.",
},

"policies": &framework.FieldSchema{
"policies": {
Type: framework.TypeCommaStringSlice,
Description: "Comma-separated list of policies associated to the group.",
},
Expand All @@ -57,10 +57,12 @@ func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*Grou
return nil, "", err
}
if entry == nil {
entries, err := s.List(ctx, "group/")
entries, err := groupList(ctx, s)
if err != nil {
return nil, "", err

}

for _, groupName := range entries {
if strings.EqualFold(groupName, n) {
entry, err = s.Get(ctx, "group/"+groupName)
Expand Down Expand Up @@ -157,13 +159,27 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}

func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List(ctx, "group/")
groups, err := groupList(ctx, req.Storage)
if err != nil {
return nil, err
}

return logical.ListResponse(groups), nil
}

func groupList(ctx context.Context, s logical.Storage) ([]string, error) {
keys, err := logical.CollectKeysPrefix(ctx, s, "group/")
if err != nil {
return nil, err
}

for i := range keys {
keys[i] = strings.TrimPrefix(keys[i], "group/")
}

return keys, nil
}

type GroupEntry struct {
Policies []string
}
Expand Down
108 changes: 108 additions & 0 deletions builtin/credential/okta/path_groups_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package okta

import (
"context"
"strings"
"testing"
"time"

"github.com/go-test/deep"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
)

func TestGroupsList(t *testing.T) {
b, storage := getBackend(t)

groups := []string{
"%20\\",
"foo",
"zfoo",
"🙂",
"foo/nested",
"foo/even/more/nested",
}

for _, group := range groups {
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "groups/" + group,
Storage: storage,
Data: map[string]interface{}{
"policies": []string{group + "_a", group + "_b"},
},
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

}

for _, group := range groups {
for _, upper := range []bool{false, true} {
groupPath := group
if upper {
groupPath = strings.ToUpper(group)
}
req := &logical.Request{
Operation: logical.ReadOperation,
Path: "groups/" + groupPath,
Storage: storage,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if resp == nil {
t.Fatal("unexpected nil response")
}

expected := []string{group + "_a", group + "_b"}

if diff := deep.Equal(resp.Data["policies"].([]string), expected); diff != nil {
t.Fatal(diff)
}
}
}

req := &logical.Request{
Operation: logical.ListOperation,
Path: "groups",
Storage: storage,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

if diff := deep.Equal(resp.Data["keys"].([]string), groups); diff != nil {
t.Fatal(diff)
}
}

func getBackend(t *testing.T) (logical.Backend, logical.Storage) {
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24

config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),

System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
StorageView: &logical.InmemStorage{},
}
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatalf("unable to create backend: %v", err)
}

return b, config.StorageView
}
15 changes: 11 additions & 4 deletions sdk/logical/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,24 @@ func ScanView(ctx context.Context, view ClearableView, cb func(path string)) err

// CollectKeys is used to collect all the keys in a view
func CollectKeys(ctx context.Context, view ClearableView) ([]string, error) {
// Accumulate the keys
var existing []string
return CollectKeysPrefix(ctx, view, "")
}

// CollectKeysPrefix is used to collect all the keys in a view with a given prefix string
func CollectKeysPrefix(ctx context.Context, view ClearableView, prefix string) ([]string, error) {
kalafut marked this conversation as resolved.
Show resolved Hide resolved
var keys []string

cb := func(path string) {
existing = append(existing, path)
if strings.HasPrefix(path, prefix) {
keys = append(keys, path)
}
}

// Scan for all the keys
if err := ScanView(ctx, view, cb); err != nil {
return nil, err
}
return existing, nil
return keys, nil
}

// ClearView is used to delete all the keys in a view
Expand Down
86 changes: 86 additions & 0 deletions sdk/logical/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package logical

import (
"context"
"testing"

"github.com/go-test/deep"
)

var keyList = []string{
"a",
"b",
"d",
"foo",
"foo42",
"foo/a/b/c",
"c/d/e/f/g",
}

func TestScanView(t *testing.T) {
s := prepKeyStorage(t)

keys := make([]string, 0)
err := ScanView(context.Background(), s, func(path string) {
keys = append(keys, path)
})

if err != nil {
t.Fatal(err)
}

if diff := deep.Equal(keys, keyList); diff != nil {
t.Fatal(diff)
}
}

func TestCollectKeys(t *testing.T) {
s := prepKeyStorage(t)

keys, err := CollectKeys(context.Background(), s)

if err != nil {
t.Fatal(err)
}

if diff := deep.Equal(keys, keyList); diff != nil {
t.Fatal(diff)
}
}

func TestCollectKeysPrefix(t *testing.T) {
s := prepKeyStorage(t)

keys, err := CollectKeysPrefix(context.Background(), s, "foo")

if err != nil {
t.Fatal(err)
}

exp := []string{
"foo",
"foo42",
"foo/a/b/c",
}

if diff := deep.Equal(keys, exp); diff != nil {
t.Fatal(diff)
}
}

func prepKeyStorage(t *testing.T) Storage {
t.Helper()
s := &InmemStorage{}

for _, key := range keyList {
if err := s.Put(context.Background(), &StorageEntry{
Key: key,
Value: nil,
SealWrap: false,
}); err != nil {
t.Fatal(err)
}
}

return s
}