diff --git a/audit/format.go b/audit/format.go index 1c0ddb3fb9af..329e7ba7125a 100644 --- a/audit/format.go +++ b/audit/format.go @@ -3,7 +3,6 @@ package audit import ( "context" "crypto/tls" - "encoding/json" "fmt" "io" "strings" @@ -15,12 +14,14 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/salt" "github.com/hashicorp/vault/sdk/logical" - "github.com/mitchellh/copystructure" ) type AuditFormatWriter interface { + // WriteRequest writes the request entry to the writer or returns an error. WriteRequest(io.Writer, *AuditRequestEntry) error + // WriteResponse writes the response entry to the writer or returns an error. WriteResponse(io.Writer, *AuditResponseEntry) error + // Salt returns a non-nil salt or an error. Salt(context.Context) (*salt.Salt, error) } @@ -54,79 +55,26 @@ func (f *AuditFormatter) FormatRequest(ctx context.Context, w io.Writer, config auth := in.Auth req := in.Request var connState *tls.ConnectionState + if auth == nil { + auth = new(logical.Auth) + } if in.Request.Connection != nil && in.Request.Connection.ConnState != nil { connState = in.Request.Connection.ConnState } if !config.Raw { - // Before we copy the structure we must nil out some data - // otherwise we will cause reflection to panic and die - if connState != nil { - in.Request.Connection.ConnState = nil - defer func() { - in.Request.Connection.ConnState = connState - }() - } - - // Copy the auth structure - if in.Auth != nil { - cp, err := copystructure.Copy(in.Auth) - if err != nil { - return err - } - auth = cp.(*logical.Auth) - } - - cp, err := copystructure.Copy(in.Request) + auth, err = HashAuth(salt, auth, config.HMACAccessor) if err != nil { return err } - req = cp.(*logical.Request) - for k, v := range req.Data { - if o, ok := v.(logical.OptMarshaler); ok { - marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ - ValueHasher: salt.GetIdentifiedHMAC, - }) - if err != nil { - return err - } - req.Data[k] = json.RawMessage(marshaled) - } - } - // Hash any sensitive information - if auth != nil { - // Cache and restore accessor in the auth - var authAccessor string - if !config.HMACAccessor && auth.Accessor != "" { - authAccessor = auth.Accessor - } - if err := Hash(salt, auth, nil); err != nil { - return err - } - if authAccessor != "" { - auth.Accessor = authAccessor - } - } - - // Cache and restore accessor in the request - var clientTokenAccessor string - if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { - clientTokenAccessor = req.ClientTokenAccessor - } - if err := Hash(salt, req, in.NonHMACReqDataKeys); err != nil { + req, err = HashRequest(salt, req, config.HMACAccessor, in.NonHMACReqDataKeys) + if err != nil { return err } - if clientTokenAccessor != "" { - req.ClientTokenAccessor = clientTokenAccessor - } } - // If auth is nil, make an empty one - if auth == nil { - auth = new(logical.Auth) - } var errString string if in.OuterErr != nil { errString = in.OuterErr.Error() @@ -209,9 +157,13 @@ func (f *AuditFormatter) FormatResponse(ctx context.Context, w io.Writer, config } // Set these to the input values at first - auth := in.Auth - req := in.Request - resp := in.Response + auth, req, resp := in.Auth, in.Request, in.Response + if auth == nil { + auth = new(logical.Auth) + } + if resp == nil { + resp = new(logical.Response) + } var connState *tls.ConnectionState if in.Request.Connection != nil && in.Request.Connection.ConnState != nil { @@ -219,120 +171,22 @@ func (f *AuditFormatter) FormatResponse(ctx context.Context, w io.Writer, config } if !config.Raw { - // Before we copy the structure we must nil out some data - // otherwise we will cause reflection to panic and die - if connState != nil { - in.Request.Connection.ConnState = nil - defer func() { - in.Request.Connection.ConnState = connState - }() - } - - // Copy the auth structure - if in.Auth != nil { - cp, err := copystructure.Copy(in.Auth) - if err != nil { - return err - } - auth = cp.(*logical.Auth) - } - - cp, err := copystructure.Copy(in.Request) + auth, err = HashAuth(salt, auth, config.HMACAccessor) if err != nil { return err } - req = cp.(*logical.Request) - for k, v := range req.Data { - if o, ok := v.(logical.OptMarshaler); ok { - marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ - ValueHasher: salt.GetIdentifiedHMAC, - }) - if err != nil { - return err - } - req.Data[k] = json.RawMessage(marshaled) - } - } - - if in.Response != nil { - cp, err := copystructure.Copy(in.Response) - if err != nil { - return err - } - resp = cp.(*logical.Response) - for k, v := range resp.Data { - if o, ok := v.(logical.OptMarshaler); ok { - marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ - ValueHasher: salt.GetIdentifiedHMAC, - }) - if err != nil { - return err - } - resp.Data[k] = json.RawMessage(marshaled) - } - } - } - - // Hash any sensitive information - - // Cache and restore accessor in the auth - if auth != nil { - var accessor string - if !config.HMACAccessor && auth.Accessor != "" { - accessor = auth.Accessor - } - if err := Hash(salt, auth, nil); err != nil { - return err - } - if accessor != "" { - auth.Accessor = accessor - } - } - // Cache and restore accessor in the request - var clientTokenAccessor string - if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { - clientTokenAccessor = req.ClientTokenAccessor - } - if err := Hash(salt, req, in.NonHMACReqDataKeys); err != nil { + req, err = HashRequest(salt, req, config.HMACAccessor, in.NonHMACReqDataKeys) + if err != nil { return err } - if clientTokenAccessor != "" { - req.ClientTokenAccessor = clientTokenAccessor - } - // Cache and restore accessor in the response - if resp != nil { - var accessor, wrappedAccessor, wrappingAccessor string - if !config.HMACAccessor && resp != nil && resp.Auth != nil && resp.Auth.Accessor != "" { - accessor = resp.Auth.Accessor - } - if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { - wrappedAccessor = resp.WrapInfo.WrappedAccessor - wrappingAccessor = resp.WrapInfo.Accessor - } - if err := Hash(salt, resp, in.NonHMACRespDataKeys); err != nil { - return err - } - if accessor != "" { - resp.Auth.Accessor = accessor - } - if wrappedAccessor != "" { - resp.WrapInfo.WrappedAccessor = wrappedAccessor - } - if wrappingAccessor != "" { - resp.WrapInfo.Accessor = wrappingAccessor - } + resp, err = HashResponse(salt, resp, config.HMACAccessor, in.NonHMACRespDataKeys) + if err != nil { + return err } } - // If things are nil, make empty to avoid panics - if auth == nil { - auth = new(logical.Auth) - } - if resp == nil { - resp = new(logical.Response) - } var errString string if in.OuterErr != nil { errString = in.OuterErr.Error() diff --git a/audit/hashstructure.go b/audit/hashstructure.go index e5035581cf43..70db3f26ed10 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -1,9 +1,9 @@ package audit import ( + "encoding/json" "errors" "reflect" - "strings" "time" "github.com/hashicorp/vault/sdk/helper/salt" @@ -19,107 +19,157 @@ func HashString(salter *salt.Salt, data string) string { return salter.GetIdentifiedHMAC(data) } -// Hash will hash the given type. This has built-in support for auth, -// requests, and responses. If it is a type that isn't recognized, then -// it will be passed through. -// -// The structure is modified in-place. -func Hash(salter *salt.Salt, raw interface{}, nonHMACDataKeys []string) error { +// HashAuth returns a hashed copy of the logical.Auth input. +func HashAuth(salter *salt.Salt, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { + if in == nil { + return nil, nil + } + fn := salter.GetIdentifiedHMAC + auth := *in - switch s := raw.(type) { - case *logical.Auth: - if s == nil { - return nil - } - if s.ClientToken != "" { - s.ClientToken = fn(s.ClientToken) - } - if s.Accessor != "" { - s.Accessor = fn(s.Accessor) - } + if auth.ClientToken != "" { + auth.ClientToken = fn(auth.ClientToken) + } + if HMACAccessor && auth.Accessor != "" { + auth.Accessor = fn(auth.Accessor) + } + return &auth, nil +} - case *logical.Request: - if s == nil { - return nil - } - if s.Auth != nil { - if err := Hash(salter, s.Auth, nil); err != nil { - return err - } - } +// HashRequest returns a hashed copy of the logical.Request input. +func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { + if in == nil { + return nil, nil + } - if s.ClientToken != "" { - s.ClientToken = fn(s.ClientToken) - } + fn := salter.GetIdentifiedHMAC + req := *in - if s.ClientTokenAccessor != "" { - s.ClientTokenAccessor = fn(s.ClientTokenAccessor) + if req.Auth != nil { + cp, err := copystructure.Copy(req.Auth) + if err != nil { + return nil, err } - data, err := HashStructure(s.Data, fn, nonHMACDataKeys) + req.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) if err != nil { - return err + return nil, err } + } - s.Data = data.(map[string]interface{}) + if req.ClientToken != "" { + req.ClientToken = fn(req.ClientToken) + } + if HMACAccessor && req.ClientTokenAccessor != "" { + req.ClientTokenAccessor = fn(req.ClientTokenAccessor) + } - case *logical.Response: - if s == nil { - return nil - } + data, err := hashMap(fn, req.Data, nonHMACDataKeys) + if err != nil { + return nil, err + } + + req.Data = data + return &req, nil +} + +func hashMap(fn func(string) string, data map[string]interface{}, nonHMACDataKeys []string) (map[string]interface{}, error) { + if data == nil { + return nil, nil + } - if s.Auth != nil { - if err := Hash(salter, s.Auth, nil); err != nil { - return err + copy, err := copystructure.Copy(data) + if err != nil { + return nil, err + } + newData := copy.(map[string]interface{}) + for k, v := range newData { + if o, ok := v.(logical.OptMarshaler); ok { + marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ + ValueHasher: fn, + }) + if err != nil { + return nil, err } + newData[k] = json.RawMessage(marshaled) } + } - if s.WrapInfo != nil { - if err := Hash(salter, s.WrapInfo, nil); err != nil { - return err - } + if err := HashStructure(newData, fn, nonHMACDataKeys); err != nil { + return nil, err + } + + return newData, nil +} + +// HashResponse returns a hashed copy of the logical.Request input. +func HashResponse(salter *salt.Salt, in *logical.Response, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Response, error) { + if in == nil { + return nil, nil + } + + fn := salter.GetIdentifiedHMAC + resp := *in + + if resp.Auth != nil { + cp, err := copystructure.Copy(resp.Auth) + if err != nil { + return nil, err } - data, err := HashStructure(s.Data, fn, nonHMACDataKeys) + resp.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) if err != nil { - return err + return nil, err } + } - s.Data = data.(map[string]interface{}) + data, err := hashMap(fn, resp.Data, nonHMACDataKeys) + if err != nil { + return nil, err + } + resp.Data = data - case *wrapping.ResponseWrapInfo: - if s == nil { - return nil + if resp.WrapInfo != nil { + var err error + resp.WrapInfo, err = HashWrapInfo(salter, resp.WrapInfo, HMACAccessor) + if err != nil { + return nil, err } + } - s.Token = fn(s.Token) - s.Accessor = fn(s.Accessor) + return &resp, nil +} - if s.WrappedAccessor != "" { - s.WrappedAccessor = fn(s.WrappedAccessor) +// HashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. +func HashWrapInfo(salter *salt.Salt, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { + if in == nil { + return nil, nil + } + + fn := salter.GetIdentifiedHMAC + wrapinfo := *in + + wrapinfo.Token = fn(wrapinfo.Token) + + if HMACAccessor { + wrapinfo.Accessor = fn(wrapinfo.Accessor) + + if wrapinfo.WrappedAccessor != "" { + wrapinfo.WrappedAccessor = fn(wrapinfo.WrappedAccessor) } } - return nil + return &wrapinfo, nil } // HashStructure takes an interface and hashes all the values within // the structure. Only _values_ are hashed: keys of objects are not. // // For the HashCallback, see the built-in HashCallbacks below. -func HashStructure(s interface{}, cb HashCallback, ignoredKeys []string) (interface{}, error) { - s, err := copystructure.Copy(s) - if err != nil { - return nil, err - } - +func HashStructure(s interface{}, cb HashCallback, ignoredKeys []string) error { walker := &hashWalker{Callback: cb, IgnoredKeys: ignoredKeys} - if err := reflectwalk.Walk(s, walker); err != nil { - return nil, err - } - - return s, nil + return reflectwalk.Walk(s, walker) } // HashCallback is the callback called for HashStructure to hash @@ -134,18 +184,25 @@ type hashWalker struct { // to be hashed. If there is an error, walking will be halted // immediately and the error returned. Callback HashCallback - // IgnoreKeys are the keys that wont have the HashCallback applied IgnoredKeys []string - - key []string - lastValue reflect.Value - loc reflectwalk.Location - cs []reflect.Value - csKey []reflect.Value - csData interface{} - sliceIndex int - unknownKeys []string + // MapElem appends the key itself (not the reflect.Value) to key. + // The last element in key is the most recently entered map key. + // Since Exit pops the last element of key, only nesting to another + // structure increases the size of this slice. + key []string + lastValue reflect.Value + // Enter appends to loc and exit pops loc. The last element of loc is thus + // the current location. + loc []reflectwalk.Location + // Map and Slice append to cs, Exit pops the last element off cs. + // The last element in cs is the most recently entered map or slice. + cs []reflect.Value + // MapElem and SliceElem append to csKey. The last element in csKey is the + // most recently entered map key or slice index. Since Exit pops the last + // element of csKey, only nesting to another structure increases the size of + // this slice. + csKey []reflect.Value } // hashTimeType stores a pre-computed reflect.Type for a time.Time so @@ -155,12 +212,12 @@ type hashWalker struct { var hashTimeType = reflect.TypeOf(time.Time{}) func (w *hashWalker) Enter(loc reflectwalk.Location) error { - w.loc = loc + w.loc = append(w.loc, loc) return nil } func (w *hashWalker) Exit(loc reflectwalk.Location) error { - w.loc = reflectwalk.None + w.loc = w.loc[:len(w.loc)-1] switch loc { case reflectwalk.Map: @@ -183,7 +240,6 @@ func (w *hashWalker) Map(m reflect.Value) error { } func (w *hashWalker) MapElem(m, k, v reflect.Value) error { - w.csData = k w.csKey = append(w.csKey, k) w.key = append(w.key, k.String()) w.lastValue = v @@ -197,7 +253,6 @@ func (w *hashWalker) Slice(s reflect.Value) error { func (w *hashWalker) SliceElem(i int, elem reflect.Value) error { w.csKey = append(w.csKey, reflect.ValueOf(i)) - w.sliceIndex = i return nil } @@ -207,20 +262,37 @@ func (w *hashWalker) Struct(v reflect.Value) error { return nil } - // If we aren't in a map value, return an error to prevent a panic - if v.Interface() != w.lastValue.Interface() { - return errors.New("time.Time value in a non map key cannot be hashed for audits") + if len(w.loc) < 3 { + // The last element of w.loc is reflectwalk.Struct, by definition. + // If len(w.loc) < 3 that means hashWalker.Walk was given a struct + // value and this is the very first step in the walk, and we don't + // currently support structs as inputs, + return errors.New("structs as direct inputs not supported") } - // Create a string value of the time. IMPORTANT: this must never change - // across Vault versions or the hash value of equivalent time.Time will - // change. - strVal := v.Interface().(time.Time).Format(time.RFC3339Nano) + // Second to last element of w.loc is location that contains this struct. + switch w.loc[len(w.loc)-2] { + case reflectwalk.MapValue: + // Create a string value of the time. IMPORTANT: this must never change + // across Vault versions or the hash value of equivalent time.Time will + // change. + strVal := v.Interface().(time.Time).Format(time.RFC3339Nano) - // Set the map value to the string instead of the time.Time object - m := w.cs[len(w.cs)-1] - mk := w.csData.(reflect.Value) - m.SetMapIndex(mk, reflect.ValueOf(strVal)) + // Set the map value to the string instead of the time.Time object + m := w.cs[len(w.cs)-1] + mk := w.csKey[len(w.cs)-1] + m.SetMapIndex(mk, reflect.ValueOf(strVal)) + case reflectwalk.SliceElem: + // Create a string value of the time. IMPORTANT: this must never change + // across Vault versions or the hash value of equivalent time.Time will + // change. + strVal := v.Interface().(time.Time).Format(time.RFC3339Nano) + + // Set the map value to the string instead of the time.Time object + s := w.cs[len(w.cs)-1] + si := int(w.csKey[len(w.cs)-1].Int()) + s.Slice(si, si+1).Index(0).Set(reflect.ValueOf(strVal)) + } // Skip this entry so that we don't walk the struct. return reflectwalk.SkipEntry @@ -230,13 +302,15 @@ func (w *hashWalker) StructField(reflect.StructField, reflect.Value) error { return nil } +// Primitive calls Callback to transform strings in-place, except for map keys. +// Strings hiding within interfaces are also transformed. func (w *hashWalker) Primitive(v reflect.Value) error { if w.Callback == nil { return nil } // We don't touch map keys - if w.loc == reflectwalk.MapKey { + if w.loc[len(w.loc)-1] == reflectwalk.MapKey { return nil } @@ -244,7 +318,6 @@ func (w *hashWalker) Primitive(v reflect.Value) error { // We only care about strings if v.Kind() == reflect.Interface { - setV = v v = v.Elem() } if v.Kind() != reflect.String { @@ -260,25 +333,17 @@ func (w *hashWalker) Primitive(v reflect.Value) error { replaceVal := w.Callback(v.String()) resultVal := reflect.ValueOf(replaceVal) - switch w.loc { - case reflectwalk.MapKey: - m := w.cs[len(w.cs)-1] - - // Delete the old value - var zero reflect.Value - m.SetMapIndex(w.csData.(reflect.Value), zero) - - // Set the new key with the existing value - m.SetMapIndex(resultVal, w.lastValue) - - // Set the key to be the new key - w.csData = resultVal + switch w.loc[len(w.loc)-1] { case reflectwalk.MapValue: // If we're in a map, then the only way to set a map value is // to set it directly. m := w.cs[len(w.cs)-1] - mk := w.csData.(reflect.Value) + mk := w.csKey[len(w.cs)-1] m.SetMapIndex(mk, resultVal) + case reflectwalk.SliceElem: + s := w.cs[len(w.cs)-1] + si := int(w.csKey[len(w.cs)-1].Int()) + s.Slice(si, si+1).Index(0).Set(resultVal) default: // Otherwise, we should be addressable setV.Set(resultVal) @@ -286,34 +351,3 @@ func (w *hashWalker) Primitive(v reflect.Value) error { return nil } - -func (w *hashWalker) removeCurrent() { - // Append the key to the unknown keys - w.unknownKeys = append(w.unknownKeys, strings.Join(w.key, ".")) - - for i := 1; i <= len(w.cs); i++ { - c := w.cs[len(w.cs)-i] - switch c.Kind() { - case reflect.Map: - // Zero value so that we delete the map key - var val reflect.Value - - // Get the key and delete it - k := w.csData.(reflect.Value) - c.SetMapIndex(k, val) - return - } - } - - panic("No container found for removeCurrent") -} - -func (w *hashWalker) replaceCurrent(v reflect.Value) { - c := w.cs[len(w.cs)-2] - switch c.Kind() { - case reflect.Map: - // Get the key and delete it - k := w.csKey[len(w.csKey)-1] - c.SetMapIndex(k, v) - } -} diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 2f50ef613688..0a361c373b0b 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -3,7 +3,9 @@ package audit import ( "context" "crypto/sha256" + "encoding/json" "fmt" + "github.com/go-test/deep" "reflect" "testing" "time" @@ -111,25 +113,85 @@ func TestHashString(t *testing.T) { } } -func TestHash(t *testing.T) { - now := time.Now() - +func TestHashAuth(t *testing.T) { cases := []struct { - Input interface{} - Output interface{} - NonHMACDataKeys []string + Input *logical.Auth + Output *logical.Auth + HMACAccessor bool }{ { &logical.Auth{ClientToken: "foo"}, &logical.Auth{ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a"}, - nil, + false, }, + { + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + TTL: 1 * time.Hour, + }, + + ClientToken: "foo", + }, + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + TTL: 1 * time.Hour, + }, + + ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a", + }, + false, + }, + } + + inmemStorage := &logical.InmemStorage{} + inmemStorage.Put(context.Background(), &logical.StorageEntry{ + Key: "salt", + Value: []byte("foo"), + }) + localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ + HMAC: sha256.New, + HMACType: "hmac-sha256", + }) + if err != nil { + t.Fatalf("Error instantiating salt: %s", err) + } + for _, tc := range cases { + input := fmt.Sprintf("%#v", tc.Input) + out, err := HashAuth(localSalt, tc.Input, tc.HMACAccessor) + if err != nil { + t.Fatalf("err: %s\n\n%s", err, input) + } + if !reflect.DeepEqual(out, tc.Output) { + t.Fatalf("bad:\nInput:\n%s\nOutput:\n%#v\nExpected output:\n%#v", input, out, tc.Output) + } + } +} + +type testOptMarshaler struct { + S string + I int +} + +func (o *testOptMarshaler) MarshalJSONWithOptions(options *logical.MarshalOptions) ([]byte, error) { + return json.Marshal(&testOptMarshaler{S: options.ValueHasher(o.S), I: o.I}) +} + +var _ logical.OptMarshaler = &testOptMarshaler{} + +func TestHashRequest(t *testing.T) { + cases := []struct { + Input *logical.Request + Output *logical.Request + NonHMACDataKeys []string + HMACAccessor bool + }{ { &logical.Request{ Data: map[string]interface{}{ "foo": "bar", "baz": "foobar", "private_key_type": certutil.PrivateKeyType("rsa"), + "om": &testOptMarshaler{S: "bar", I: 1}, }, }, &logical.Request{ @@ -137,10 +199,47 @@ func TestHash(t *testing.T) { "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", "baz": "foobar", "private_key_type": "hmac-sha256:995230dca56fffd310ff591aa404aab52b2abb41703c787cfa829eceb4595bf1", + "om": json.RawMessage(`{"S":"hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317","I":1}`), }, }, []string{"baz"}, + false, }, + } + + inmemStorage := &logical.InmemStorage{} + inmemStorage.Put(context.Background(), &logical.StorageEntry{ + Key: "salt", + Value: []byte("foo"), + }) + localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ + HMAC: sha256.New, + HMACType: "hmac-sha256", + }) + if err != nil { + t.Fatalf("Error instantiating salt: %s", err) + } + for _, tc := range cases { + input := fmt.Sprintf("%#v", tc.Input) + out, err := HashRequest(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) + if err != nil { + t.Fatalf("err: %s\n\n%s", err, input) + } + if diff := deep.Equal(out, tc.Output); len(diff) > 0 { + t.Fatalf("bad:\nInput:\n%s\nDiff:\n%#v", input, diff) + } + } +} + +func TestHashResponse(t *testing.T) { + now := time.Now() + + cases := []struct { + Input *logical.Response + Output *logical.Response + NonHMACDataKeys []string + HMACAccessor bool + }{ { &logical.Response{ Data: map[string]interface{}{ @@ -149,6 +248,7 @@ func TestHash(t *testing.T) { // Responses can contain time values, so test that with // a known fixed value. "bar": now, + "om": &testOptMarshaler{S: "bar", I: 1}, }, WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, @@ -163,6 +263,7 @@ func TestHash(t *testing.T) { "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", "baz": "foobar", "bar": now.Format(time.RFC3339Nano), + "om": json.RawMessage(`{"S":"hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317","I":1}`), }, WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, @@ -173,28 +274,7 @@ func TestHash(t *testing.T) { }, }, []string{"baz"}, - }, - { - "foo", - "foo", - nil, - }, - { - &logical.Auth{ - LeaseOptions: logical.LeaseOptions{ - TTL: 1 * time.Hour, - }, - - ClientToken: "foo", - }, - &logical.Auth{ - LeaseOptions: logical.LeaseOptions{ - TTL: 1 * time.Hour, - }, - - ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a", - }, - nil, + true, }, } @@ -212,16 +292,12 @@ func TestHash(t *testing.T) { } for _, tc := range cases { input := fmt.Sprintf("%#v", tc.Input) - if err := Hash(localSalt, tc.Input, tc.NonHMACDataKeys); err != nil { + out, err := HashResponse(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) + if err != nil { t.Fatalf("err: %s\n\n%s", err, input) } - if _, ok := tc.Input.(*logical.Response); ok { - if !reflect.DeepEqual(tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) { - t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output:\n%#v", input, tc.Input.(*logical.Response).WrapInfo, tc.Output.(*logical.Response).WrapInfo) - } - } - if !reflect.DeepEqual(tc.Input, tc.Output) { - t.Fatalf("bad:\nInput:\n%s\nTest case input:\n%#v\nTest case output:\n%#v", input, tc.Input, tc.Output) + if diff := deep.Equal(out, tc.Output); len(diff) > 0 { + t.Fatalf("bad:\nInput:\n%s\nDiff:\n%#v", input, diff) } } } @@ -230,8 +306,8 @@ func TestHashWalker(t *testing.T) { replaceText := "foo" cases := []struct { - Input interface{} - Output interface{} + Input map[string]interface{} + Output map[string]interface{} }{ { map[string]interface{}{ @@ -253,14 +329,68 @@ func TestHashWalker(t *testing.T) { } for _, tc := range cases { - output, err := HashStructure(tc.Input, func(string) string { + err := HashStructure(tc.Input, func(string) string { return replaceText - }, []string{}) + }, nil) if err != nil { t.Fatalf("err: %s\n\n%#v", err, tc.Input) } - if !reflect.DeepEqual(output, tc.Output) { - t.Fatalf("bad:\n\n%#v\n\n%#v", tc.Input, output) + if !reflect.DeepEqual(tc.Input, tc.Output) { + t.Fatalf("bad:\n\n%#v\n\n%#v", tc.Input, tc.Output) + } + } +} + +func TestHashWalker_TimeStructs(t *testing.T) { + replaceText := "bar" + + now := time.Now() + cases := []struct { + Input map[string]interface{} + Output map[string]interface{} + }{ + // Should not touch map keys of type time.Time. + { + map[string]interface{}{ + "hello": map[time.Time]struct{}{ + now: {}, + }, + }, + map[string]interface{}{ + "hello": map[time.Time]struct{}{ + now: {}, + }, + }, + }, + // Should handle map values of type time.Time. + { + map[string]interface{}{ + "hello": now, + }, + map[string]interface{}{ + "hello": now.Format(time.RFC3339Nano), + }, + }, + // Should handle slice values of type time.Time. + { + map[string]interface{}{ + "hello": []interface{}{"foo", now, "foo2"}, + }, + map[string]interface{}{ + "hello": []interface{}{"foobar", now.Format(time.RFC3339Nano), "foo2bar"}, + }, + }, + } + + for _, tc := range cases { + err := HashStructure(tc.Input, func(s string) string { + return s + replaceText + }, nil) + if err != nil { + t.Fatalf("err: %v\n\n%#v", err, tc.Input) + } + if !reflect.DeepEqual(tc.Input, tc.Output) { + t.Fatalf("bad:\n\n%#v\n\n%#v", tc.Input, tc.Output) } } } diff --git a/vault/testing.go b/vault/testing.go index 3e19c85ded39..0120a1911077 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -190,9 +190,14 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo HMACType: "hmac-sha256", } config.SaltView = view - return &noopAudit{ + + n := &noopAudit{ Config: config, - }, nil + } + n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ + SaltFunc: n.Salt, + } + return n, nil }, } @@ -591,6 +596,8 @@ type noopAudit struct { Config *audit.BackendConfig salt *salt.Salt saltMutex sync.RWMutex + formatter audit.AuditFormatter + records [][]byte } func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { @@ -601,11 +608,23 @@ func (n *noopAudit) GetHash(ctx context.Context, data string) (string, error) { return salt.GetIdentifiedHMAC(data), nil } -func (n *noopAudit) LogRequest(_ context.Context, _ *logical.LogInput) error { +func (n *noopAudit) LogRequest(ctx context.Context, in *logical.LogInput) error { + var w bytes.Buffer + err := n.formatter.FormatRequest(ctx, &w, audit.FormatterConfig{}, in) + if err != nil { + return err + } + n.records = append(n.records, w.Bytes()) return nil } -func (n *noopAudit) LogResponse(_ context.Context, _ *logical.LogInput) error { +func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error { + var w bytes.Buffer + err := n.formatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in) + if err != nil { + return err + } + n.records = append(n.records, w.Bytes()) return nil } @@ -647,14 +666,13 @@ func AddNoopAudit(conf *CoreConfig) { Key: "salt", Value: []byte("foo"), }) - config.SaltConfig = &salt.Config{ - HMAC: sha256.New, - HMACType: "hmac-sha256", - } - config.SaltView = view - return &noopAudit{ + n := &noopAudit{ Config: config, - }, nil + } + n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ + SaltFunc: n.Salt, + } + return n, nil }, } } @@ -1348,6 +1366,12 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te coreConfig.DevToken = base.DevToken coreConfig.CounterSyncInterval = base.CounterSyncInterval + + } + + addAuditBackend := len(coreConfig.AuditBackends) == 0 + if addAuditBackend { + AddNoopAudit(coreConfig) } if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) { @@ -1567,6 +1591,26 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatal(err) } testCluster.ID = cluster.ID + + if addAuditBackend { + // Enable auditing. + auditReq := &logical.Request{ + Operation: logical.UpdateOperation, + ClientToken: testCluster.RootToken, + Path: "sys/audit/noop", + Data: map[string]interface{}{ + "type": "noop", + }, + } + resp, err = cores[0].HandleRequest(namespace.RootContext(ctx), auditReq) + if err != nil { + t.Fatal(err) + } + + if resp.IsError() { + t.Fatal(err) + } + } } getAPIClient := func(port int, tlsConfig *tls.Config) *api.Client {