From f3f8a90407bb3dcbf9e20379244d93f53ed941b9 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 9 May 2017 12:22:38 -0400 Subject: [PATCH] Allow non-strings to be used to set `ttl` field in generic. Fixes #2697 --- vault/logical_passthrough.go | 15 +++++----- vault/logical_passthrough_test.go | 47 +++++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/vault/logical_passthrough.go b/vault/logical_passthrough.go index eb52a3f62c9e..cd936690fb17 100644 --- a/vault/logical_passthrough.go +++ b/vault/logical_passthrough.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -126,14 +126,13 @@ func (b *PassthroughBackend) handleRead( } // Check if there is a ttl key - var ttl string - ttl, _ = rawData["ttl"].(string) - if len(ttl) == 0 { - ttl, _ = rawData["lease"].(string) - } ttlDuration := b.System().DefaultLeaseTTL() - if len(ttl) != 0 { - dur, err := parseutil.ParseDurationSecond(ttl) + ttlInt, ok := rawData["ttl"] + if !ok { + ttlInt, ok = rawData["lease"] + } + if ok { + dur, err := parseutil.ParseDurationSecond(ttlInt) if err == nil { ttlDuration = dur } diff --git a/vault/logical_passthrough_test.go b/vault/logical_passthrough_test.go index bd33d657b3f5..b7bc3999e3d2 100644 --- a/vault/logical_passthrough_test.go +++ b/vault/logical_passthrough_test.go @@ -1,10 +1,12 @@ package vault import ( + "encoding/json" "reflect" "testing" "time" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/logical" ) @@ -49,10 +51,19 @@ func TestPassthroughBackend_Write(t *testing.T) { } func TestPassthroughBackend_Read(t *testing.T) { - test := func(b logical.Backend, ttlType string, leased bool) { + test := func(b logical.Backend, ttlType string, ttl interface{}, leased bool) { req := logical.TestRequest(t, logical.UpdateOperation, "foo") req.Data["raw"] = "test" - req.Data[ttlType] = "1h" + var reqTTL interface{} + switch ttl.(type) { + case int64: + reqTTL = ttl.(int64) + case string: + reqTTL = ttl.(string) + default: + t.Fatal("unknown ttl type") + } + req.Data[ttlType] = reqTTL storage := req.Storage if _, err := b.HandleRequest(req); err != nil { @@ -67,16 +78,34 @@ func TestPassthroughBackend_Read(t *testing.T) { t.Fatalf("err: %v", err) } + expectedTTL, err := parseutil.ParseDurationSecond(ttl) + if err != nil { + t.Fatal(err) + } + + // What comes back if an int is passed in is a json.Number which is + // actually aliased as a string so to make the deep equal happy if it's + // actually a number we set it to an int64 + var respTTL interface{} = resp.Data[ttlType] + _, ok := respTTL.(json.Number) + if ok { + respTTL, err = respTTL.(json.Number).Int64() + if err != nil { + t.Fatal(err) + } + resp.Data[ttlType] = respTTL + } + expected := &logical.Response{ Secret: &logical.Secret{ LeaseOptions: logical.LeaseOptions{ Renewable: true, - TTL: time.Hour, + TTL: expectedTTL, }, }, Data: map[string]interface{}{ "raw": "test", - ttlType: "1h", + ttlType: reqTTL, }, } @@ -86,15 +115,15 @@ func TestPassthroughBackend_Read(t *testing.T) { resp.Secret.InternalData = nil resp.Secret.LeaseID = "" if !reflect.DeepEqual(resp, expected) { - t.Fatalf("bad response.\n\nexpected: %#v\n\nGot: %#v", expected, resp) + t.Fatalf("bad response.\n\nexpected:\n%#v\n\nGot:\n%#v", expected, resp) } } b := testPassthroughLeasedBackend() - test(b, "lease", true) - test(b, "ttl", true) + test(b, "lease", "1h", true) + test(b, "ttl", "5", true) b = testPassthroughBackend() - test(b, "lease", false) - test(b, "ttl", false) + test(b, "lease", int64(10), false) + test(b, "ttl", "40s", false) } func TestPassthroughBackend_Delete(t *testing.T) {