Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault Agent Cache Auto-Auth SSRF Protection #7627

Merged
merged 29 commits into from
Oct 11, 2019
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
853a109
implement SSRF protection header
mjarmy Oct 3, 2019
84e3086
add test for SSRF protection header
mjarmy Oct 4, 2019
b499b32
cleanup
mjarmy Oct 4, 2019
6f07e18
refactor
mjarmy Oct 8, 2019
7044ac9
merge from master
mjarmy Oct 8, 2019
9adadc5
implement SSRF header on a per-listener basis
mjarmy Oct 9, 2019
d10eef8
cleanup
mjarmy Oct 9, 2019
0e5f902
cleanup
mjarmy Oct 9, 2019
4a03773
creat unit test for agent SSRF
mjarmy Oct 10, 2019
4ca14d8
improve unit test for agent SSRF
mjarmy Oct 10, 2019
0491fec
add VaultRequest SSRF header to CLI
mjarmy Oct 10, 2019
91e22a5
merge from master
mjarmy Oct 10, 2019
2125a20
fix unit test
mjarmy Oct 10, 2019
467698c
cleanup
mjarmy Oct 10, 2019
199ff09
merge from master
mjarmy Oct 11, 2019
a915b64
improve test suite
mjarmy Oct 11, 2019
34ff6dd
simplify check for Vault-Request header
mjarmy Oct 11, 2019
6d9c4a1
add constant for Vault-Request header
mjarmy Oct 11, 2019
22e199e
improve test suite
mjarmy Oct 11, 2019
14ee72d
change 'config' to 'agentConfig'
mjarmy Oct 11, 2019
5ea8878
Revert "change 'config' to 'agentConfig'"
mjarmy Oct 11, 2019
3e6acbc
do not remove header from request
mjarmy Oct 11, 2019
15ca067
change header name to X-Vault-Request
mjarmy Oct 11, 2019
44b75af
merge from master
mjarmy Oct 11, 2019
12c8c2b
simplify http.Handler logic
mjarmy Oct 11, 2019
335c5ec
cleanup
mjarmy Oct 11, 2019
ca26eba
simplify http.Handler logic
mjarmy Oct 11, 2019
14559e5
use stdlib errors package
mjarmy Oct 11, 2019
28e9dbc
merge from master
mjarmy Oct 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,14 @@ func NewClient(c *Config) (*Client, error) {
}

client := &Client{
addr: u,
config: c,
addr: u,
config: c,
headers: make(http.Header),
}

// Add the VaultRequest SSRF protection header
client.headers["Vault-Request"] = []string{"true"}
mjarmy marked this conversation as resolved.
Show resolved Hide resolved

if token := os.Getenv(EnvVaultToken); token != "" {
client.token = token
}
Expand Down
28 changes: 20 additions & 8 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/config"
agentConfig "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/command/agent/sink/inmem"
Expand Down Expand Up @@ -192,7 +192,7 @@ func (c *AgentCommand) Run(args []string) int {
}

// Load the configuration
config, err := config.LoadConfig(c.flagConfigs[0])
config, err := agentConfig.LoadConfig(c.flagConfigs[0])
if err != nil {
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
return 1
Expand Down Expand Up @@ -418,12 +418,6 @@ func (c *AgentCommand) Run(args []string) int {
})
}

// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))

mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))

var listeners []net.Listener
for i, lnConfig := range config.Listeners {
ln, tlsConf, err := cache.StartListener(lnConfig)
Expand All @@ -434,6 +428,24 @@ func (c *AgentCommand) Run(args []string) int {

listeners = append(listeners, ln)

// Parse 'require_request_header' listener config option
var requireRequestHeader bool
if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
switch v {
case true:
requireRequestHeader = true
case false /* noop */ :
default:
c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
return 1
}
}

// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, requireRequestHeader))
mjarmy marked this conversation as resolved.
Show resolved Hide resolved

scheme := "https://"
if tlsConf == nil {
scheme = "http://"
Expand Down
16 changes: 15 additions & 1 deletion command/agent/cache/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"io/ioutil"
"net/http"
"reflect"
"time"

"github.com/hashicorp/errwrap"
Expand All @@ -20,10 +21,23 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink) http.Handler {
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink, requireRequestHeader bool) http.Handler {
mjarmy marked this conversation as resolved.
Show resolved Hide resolved
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("received request", "method", r.Method, "path", r.URL.Path)

// Check for the required request header
if requireRequestHeader {
val, ok := r.Header["Vault-Request"]
if !ok || !reflect.DeepEqual(val, []string{"true"}) {
mjarmy marked this conversation as resolved.
Show resolved Hide resolved
logical.RespondError(w, http.StatusPreconditionFailed, errors.New("missing 'Vault-Request' header"))
mjarmy marked this conversation as resolved.
Show resolved Hide resolved
return
}

// Remove the required request header
delete(r.Header, "Vault-Request")
mjarmy marked this conversation as resolved.
Show resolved Hide resolved
}

// Get token from the header
token := r.Header.Get(consts.AuthHeaderName)
if token == "" && inmemSink != nil {
logger.Debug("using auto auth token", "method", r.Method, "path", r.URL.Path)
Expand Down
3 changes: 3 additions & 0 deletions command/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type Listener struct {
Config map[string]interface{}
}

// RequireRequestHeader is a listener configuration option
const RequireRequestHeader = "require_request_header"

type AutoAuth struct {
Method *Method `hcl:"-"`
Sinks []*Sink `hcl:"sinks"`
Expand Down
228 changes: 228 additions & 0 deletions command/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package command

import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"sync"
"testing"
"time"

hclog "github.com/hashicorp/go-hclog"
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
"github.com/hashicorp/vault/command/agent"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/logging"
Expand Down Expand Up @@ -370,3 +374,227 @@ auto_auth {
t.Fatal("sink 1/2 values don't match")
}
}

func TestAgent_RequireRequestHeader(t *testing.T) {

// request is a helper function that issues HTTP requests.
request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} {
resp, err := client.RawRequest(req)
if err != nil {
t.Fatalf("err: %s", err)
}
if resp.StatusCode != expectedStatusCode {
t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode)
}

bytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err: %s", err)
}
if len(bytes) == 0 {
return nil
}

var body map[string]interface{}
err = json.Unmarshal(bytes, &body)
if err != nil {
t.Fatalf("err: %s", err)
}
return body
}

// makeTempFile is a helper function that creates a temp file and
// populates it.
makeTempFile := func(name, contents string) string {
f, err := ioutil.TempFile("", name)
if err != nil {
t.Fatal(err)
}
path := f.Name()
f.WriteString(contents)
f.Close()
return path
}

// newApiClient is a helper function that creates an *api.Client.
newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client {
conf := api.DefaultConfig()
conf.Address = addr
cli, err := api.NewClient(conf)
if err != nil {
t.Fatalf("err: %s", err)
}

if !includeVaultRequestHeader {
h := cli.Headers()
delete(h, "Vault-Request")
cli.SetHeaders(h)
}
return cli
}

//----------------------------------------------------
// Start the server and agent
//----------------------------------------------------

// Start a vault server
logger := logging.NewVaultLogger(hclog.Trace)
cluster := vault.NewTestCluster(t,
&vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
},
&vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
serverClient := cluster.Cores[0].Client

// Enable the approle auth method
req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
req.BodyBytes = []byte(`{
"type": "approle"
}`)
request(serverClient, req, 204)

// Create a named role
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
req.BodyBytes = []byte(`{
"secret_id_num_uses": "10",
"secret_id_ttl": "1m",
"token_max_ttl": "1m",
"token_num_uses": "10",
"token_ttl": "1m"
}`)
request(serverClient, req, 204)

// Fetch the RoleID of the named role
req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
body := request(serverClient, req, 200)
data := body["data"].(map[string]interface{})
roleID := data["role_id"].(string)

// Get a SecretID issued against the named role
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
body = request(serverClient, req, 200)
data = body["data"].(map[string]interface{})
secretID := data["secret_id"].(string)

// Write the RoleID and SecretID to temp files
roleIDPath := makeTempFile("role_id.txt", roleID+"\n")
secretIDPath := makeTempFile("secret_id.txt", secretID+"\n")
defer os.Remove(roleIDPath)
defer os.Remove(secretIDPath)

// Get a temp file path we can use for the sink
sinkPath := makeTempFile("sink.txt", "")
defer os.Remove(sinkPath)

// Create a config file
config := `
auto_auth {
method "approle" {
mount_path = "auth/approle"
config = {
role_id_file_path = "%s"
secret_id_file_path = "%s"
}
}

sink "file" {
config = {
path = "%s"
}
}
}

cache {
use_auto_auth_token = true
}

listener "tcp" {
address = "127.0.0.1:8101"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:8102"
tls_disable = true
require_request_header = false
}
listener "tcp" {
address = "127.0.0.1:8103"
tls_disable = true
require_request_header = true
}
`
config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath)
configPath := makeTempFile("config.hcl", config)
defer os.Remove(configPath)

// Start the agent
ui, cmd := testAgentCommand(t, logger)
cmd.client = serverClient
cmd.startedCh = make(chan struct{})

wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
code := cmd.Run([]string{"-config", configPath})
if code != 0 {
t.Errorf("non-zero return code when running agent: %d", code)
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
}
wg.Done()
}()

select {
case <-cmd.startedCh:
case <-time.After(5 * time.Second):
t.Errorf("timeout")
}

// defer agent shutdown
defer func() {
cmd.ShutdownCh <- struct{}{}
wg.Wait()
}()

//----------------------------------------------------
// Perform the tests
//----------------------------------------------------

// Test against a listener configuration that omits
// 'require_request_header', with the header missing from the request.
agentClient := newApiClient("http://127.0.0.1:8101", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'false', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8102", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the header missing from the request.
agentClient = newApiClient("http://127.0.0.1:8103", false)
req = agentClient.NewRequest("GET", "/v1/sys/health")
resp, err := agentClient.RawRequest(req)
if err == nil {
t.Fatalf("expected error")
}
if resp.StatusCode != 412 {
mjarmy marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("expected status code %d, not %d", 412, resp.StatusCode)
}

// Test against a listener configuration that sets 'require_request_header'
// to 'true', with the header present in the request.
agentClient = newApiClient("http://127.0.0.1:8103", true)
req = agentClient.NewRequest("GET", "/v1/sys/health")
request(agentClient, req, 200)
}