Skip to content

Commit

Permalink
Allow non-strings to be used to set ttl field in generic.
Browse files Browse the repository at this point in the history
Fixes #2697
  • Loading branch information
jefferai committed May 9, 2017
1 parent 6f030e4 commit f3f8a90
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
15 changes: 7 additions & 8 deletions vault/logical_passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"fmt"
"strings"

"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -126,14 +126,13 @@ func (b *PassthroughBackend) handleRead(
}

// Check if there is a ttl key
var ttl string
ttl, _ = rawData["ttl"].(string)
if len(ttl) == 0 {
ttl, _ = rawData["lease"].(string)
}
ttlDuration := b.System().DefaultLeaseTTL()
if len(ttl) != 0 {
dur, err := parseutil.ParseDurationSecond(ttl)
ttlInt, ok := rawData["ttl"]
if !ok {
ttlInt, ok = rawData["lease"]
}
if ok {
dur, err := parseutil.ParseDurationSecond(ttlInt)
if err == nil {
ttlDuration = dur
}
Expand Down
47 changes: 38 additions & 9 deletions vault/logical_passthrough_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package vault

import (
"encoding/json"
"reflect"
"testing"
"time"

"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/logical"
)

Expand Down Expand Up @@ -49,10 +51,19 @@ func TestPassthroughBackend_Write(t *testing.T) {
}

func TestPassthroughBackend_Read(t *testing.T) {
test := func(b logical.Backend, ttlType string, leased bool) {
test := func(b logical.Backend, ttlType string, ttl interface{}, leased bool) {
req := logical.TestRequest(t, logical.UpdateOperation, "foo")
req.Data["raw"] = "test"
req.Data[ttlType] = "1h"
var reqTTL interface{}
switch ttl.(type) {
case int64:
reqTTL = ttl.(int64)
case string:
reqTTL = ttl.(string)
default:
t.Fatal("unknown ttl type")
}
req.Data[ttlType] = reqTTL
storage := req.Storage

if _, err := b.HandleRequest(req); err != nil {
Expand All @@ -67,16 +78,34 @@ func TestPassthroughBackend_Read(t *testing.T) {
t.Fatalf("err: %v", err)
}

expectedTTL, err := parseutil.ParseDurationSecond(ttl)
if err != nil {
t.Fatal(err)
}

// What comes back if an int is passed in is a json.Number which is
// actually aliased as a string so to make the deep equal happy if it's
// actually a number we set it to an int64
var respTTL interface{} = resp.Data[ttlType]
_, ok := respTTL.(json.Number)
if ok {
respTTL, err = respTTL.(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
resp.Data[ttlType] = respTTL
}

expected := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
Renewable: true,
TTL: time.Hour,
TTL: expectedTTL,
},
},
Data: map[string]interface{}{
"raw": "test",
ttlType: "1h",
ttlType: reqTTL,
},
}

Expand All @@ -86,15 +115,15 @@ func TestPassthroughBackend_Read(t *testing.T) {
resp.Secret.InternalData = nil
resp.Secret.LeaseID = ""
if !reflect.DeepEqual(resp, expected) {
t.Fatalf("bad response.\n\nexpected: %#v\n\nGot: %#v", expected, resp)
t.Fatalf("bad response.\n\nexpected:\n%#v\n\nGot:\n%#v", expected, resp)
}
}
b := testPassthroughLeasedBackend()
test(b, "lease", true)
test(b, "ttl", true)
test(b, "lease", "1h", true)
test(b, "ttl", "5", true)
b = testPassthroughBackend()
test(b, "lease", false)
test(b, "ttl", false)
test(b, "lease", int64(10), false)
test(b, "ttl", "40s", false)
}

func TestPassthroughBackend_Delete(t *testing.T) {
Expand Down

0 comments on commit f3f8a90

Please sign in to comment.