Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for distributed groups claims on Azure #120

Merged
merged 6 commits into from
Jun 25, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/hashicorp/vault-plugin-auth-jwt

go 1.13
go 1.14

require (
github.com/coreos/go-oidc v2.1.0+incompatible
Expand All @@ -19,7 +19,7 @@ require (
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/ryanuber/go-glob v1.0.0
github.com/stretchr/testify v1.3.0
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
golang.org/x/sync v0.0.0-20190423024810-112230192c58
golang.org/x/text v0.3.2 // indirect
google.golang.org/appengine v1.5.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7 h1:fHDIZ2oxGnUZRN6WgWFCbYBjH9uqVPRCUVUDhs0wnbA=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down
4 changes: 2 additions & 2 deletions path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ func TestConfig_OIDC_Write_ProviderConfig(t *testing.T) {
req.Data = map[string]interface{}{
"oidc_discovery_url": "https://team-vault.auth0.com/",
"provider_config": map[string]interface{}{
"provider": "empty",
"provider": "azure",
"extraOptions": "abound",
},
}
Expand All @@ -430,7 +430,7 @@ func TestConfig_OIDC_Write_ProviderConfig(t *testing.T) {
OIDCResponseTypes: []string{},
OIDCDiscoveryURL: "https://team-vault.auth0.com/",
ProviderConfig: map[string]interface{}{
"provider": "empty",
"provider": "azure",
"extraOptions": "abound",
},
}
Expand Down
29 changes: 25 additions & 4 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,9 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return alias, groupAliases, nil
}

groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
return nil, nil, fmt.Errorf("%q claim not found in token", role.GroupsClaim)
groupsClaimRaw, err := b.fetchGroups(allClaims, role)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
}

groups, ok := normalizeList(groupsClaimRaw)
Expand All @@ -361,6 +360,28 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return alias, groupAliases, nil
}

// Checks if there's a custom provider_config and calls FetchGroups() if implemented
func (b *jwtAuthBackend) fetchGroups(allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing that the custom provider Initialize() method within NewProviderConfig() will end up being called upon each login. I think that's okay given that the azure provider doesn't do much there. I'm wondering if we would still want that to happen if there is a custom provider with more or potentially slow initialization code. It appears that the customer provider is already initialized when the backend gets its config.

Not a blocker. Just wanted to see what your thoughts were :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, I guess I envisioned the Initialize() method would mainly be used to translate the "provider_config" values into the CustomProvider struct, and not actually run any API calls against a provider (azure in this case). But it's something to keep in mind as this evolves for sure.

if err != nil {
return nil, fmt.Errorf("failed to load custom provider config: %s", err)
}
// If the custom provider implements interface GroupsFetcher, call it,
// otherwise fall through to the default method
if pConfig != nil {
if gf, ok := pConfig.(GroupsFetcher); ok {
return gf.FetchGroups(b, allClaims, role)
}
}
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
return nil, fmt.Errorf("%q claim not found in token", role.GroupsClaim)
}

return groupsClaimRaw, nil
}

const (
pathLoginHelpSyn = `
Authenticates to Vault using a JWT (or OIDC) token.
Expand Down
2 changes: 2 additions & 0 deletions path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"gopkg.in/square/go-jose.v2/jwt"
)

type H map[string]interface{}

type testConfig struct {
oidc bool
role_type_oidc bool
Expand Down
185 changes: 185 additions & 0 deletions provider_azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package jwtauth

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"

"github.com/coreos/go-oidc"
log "github.com/hashicorp/go-hclog"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

const (
// The old MS graph API requires setting an api-version query parameter
windowsGraphHost = "graph.windows.net"
windowsAPIVersion = "1.6"

// Distributed claim fields
claimNamesField = "_claim_names"
claimSourcesField = "_claim_sources"
)

// AzureProvider is used for Azure-specific configuration
type AzureProvider struct {
// Context for azure calls
ctx context.Context

// OIDC provider
provider *oidc.Provider
}

// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface
func (a *AzureProvider) Initialize(jc *jwtConfig) error {
return nil
}

// SensitiveKeys - satisfying the CustomProvider interface
func (a *AzureProvider) SensitiveKeys() []string {
return []string{}
}

// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
func (a *AzureProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
// If the "groups" claim is missing, it might be because the user is a
// member of more than 200 groups, which means the token contains
// distributed claim information. Attempt to look that up here.
azureClaimSourcesURL, err := a.getClaimSource(b.Logger(), allClaims, role)
if err != nil {
return nil, fmt.Errorf("unable to get claim sources: %s", err)
}

// Get provider because we'll need to get a new token for microsoft's
// graph API, specifically the old graph API
provider, err := b.getProvider(b.cachedConfig)
if err != nil {
return nil, fmt.Errorf("unable to get provider: %s", err)
}
a.provider = provider

a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM)
if err != nil {
return nil, fmt.Errorf("unable to create CA Context: %s", err)
}

azureGroups, err := a.getAzureGroups(b.Logger(), azureClaimSourcesURL, b.cachedConfig)
if err != nil {
return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err)
}
groupsClaimRaw = azureGroups
}
b.Logger().Debug(fmt.Sprintf("groups claim raw is %v", groupsClaimRaw))
return groupsClaimRaw, nil
}

// In Azure, if you are indirectly member of more than 200 groups, they will
// send _claim_names and _claim_sources instead of the groups, per OIDC Core
// 1.0, section 5.6.2:
// https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims
// In the future this could be used with other providers as well. Example:
//
// {
// "_claim_names": {
// "groups": "src1"
// },
// "_claim_sources": {
// "src1": {
// "endpoint": "https://graph.windows.net...."
// }
// }
// }
//
// For this to work, "profile" should be set in "oidc_scopes" in the vault oidc role.
//
func (a *AzureProvider) getClaimSource(logger log.Logger, allClaims map[string]interface{}, role *jwtRole) (string, error) {
// Get the source key for the groups claim
name := fmt.Sprintf("/%s/%s", claimNamesField, role.GroupsClaim)
groupsClaimSource := getClaim(logger, allClaims, name)
if groupsClaimSource == nil {
return "", fmt.Errorf("unable to locate groups claim %q in %s", role.GroupsClaim, claimNamesField)
}
// Get the endpoint source for the groups claim
endpoint := fmt.Sprintf("/%s/%s/endpoint", claimSourcesField, groupsClaimSource.(string))
val := getClaim(logger, allClaims, endpoint)
if val == nil {
return "", fmt.Errorf("unable to locate %s in claims", endpoint)
}
logger.Debug(fmt.Sprintf("found Azure Graph API endpoint for group membership: %v", val))
return fmt.Sprintf("%v", val), nil
}

// Fetch user groups from the Azure AD Graph API
func (a *AzureProvider) getAzureGroups(logger log.Logger, groupsURL string, c *jwtConfig) (interface{}, error) {
urlParsed, err := url.Parse(groupsURL)
if err != nil {
return nil, fmt.Errorf("failed to parse distributed groups source url %s: %s", groupsURL, err)
}
token, err := a.getAzureToken(logger, c, urlParsed.Host)
if err != nil {
return nil, fmt.Errorf("unable to get token: %s", err)
}
payload := strings.NewReader("{\"securityEnabledOnly\": false}")
req, _ := http.NewRequest("POST", groupsURL, payload)
austingebauer marked this conversation as resolved.
Show resolved Hide resolved
req.Header.Add("content-type", "application/json")
req.Header.Add("authorization", fmt.Sprintf("Bearer %s", token))

// If endpoint is the old windows graph api, add api-version
if urlParsed.Host == windowsGraphHost {
query := req.URL.Query()
query.Add("api-version", windowsAPIVersion)
req.URL.RawQuery = query.Encode()
}
client := http.DefaultClient
if c, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
client = c
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to call Azure AD Graph API: %s", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("failed to read Azure AD Graph API response: %s", err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get groups: %s", string(body))
}

var target azureGroups
if err := json.Unmarshal(body, &target); err != nil {
return nil, fmt.Errorf("unabled to decode response: %s", err)
}
return target.Value, nil
}

// Login to Azure, using client id and secret.
func (a *AzureProvider) getAzureToken(logger log.Logger, c *jwtConfig, host string) (string, error) {
austingebauer marked this conversation as resolved.
Show resolved Hide resolved
config := &clientcredentials.Config{
ClientID: c.OIDCClientID,
ClientSecret: c.OIDCClientSecret,
TokenURL: a.provider.Endpoint().TokenURL,
Scopes: []string{
"openid",
"profile",
"https://" + host + "/.default",
},
}
token, err := config.Token(a.ctx)
if err != nil {
return "", fmt.Errorf("failed to fetch Azure token: %s", err)
}
return token.AccessToken, nil
}

type azureGroups struct {
Value []interface{} `json:"value"`
}
Loading