Skip to content

Commit

Permalink
Simplify TTL/MaxTTL logic in SSH CA paths and sane with the rest of h…
Browse files Browse the repository at this point in the history
…ow (#3507)

Vault parses/returns TTLs.
  • Loading branch information
jefferai authored Oct 30, 2017
1 parent ad6b4df commit 3e81fe4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 67 deletions.
67 changes: 22 additions & 45 deletions builtin/logical/ssh/path_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func pathRoles(b *backend) *framework.Path {
`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The lease duration if no specific lease duration is
Expand All @@ -184,7 +184,7 @@ func pathRoles(b *backend) *framework.Path {
the value of max_ttl.`,
},
"max_ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
The maximum allowed lease duration
Expand Down Expand Up @@ -433,9 +433,9 @@ func (b *backend) pathRoleWrite(req *logical.Request, d *framework.FieldData) (*
}

func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework.FieldData) (*sshRole, *logical.Response) {
ttl := time.Duration(data.Get("ttl").(int)) * time.Second
maxTTL := time.Duration(data.Get("max_ttl").(int)) * time.Second
role := &sshRole{
MaxTTL: data.Get("max_ttl").(string),
TTL: data.Get("ttl").(string),
AllowedCriticalOptions: data.Get("allowed_critical_options").(string),
AllowedExtensions: data.Get("allowed_extensions").(string),
AllowUserCertificates: data.Get("allow_user_certificates").(bool),
Expand All @@ -457,44 +457,12 @@ func (b *backend) createCARole(allowedUsers, defaultUser string, data *framework
defaultCriticalOptions := convertMapToStringValue(data.Get("default_critical_options").(map[string]interface{}))
defaultExtensions := convertMapToStringValue(data.Get("default_extensions").(map[string]interface{}))

var maxTTL time.Duration
maxSystemTTL := b.System().MaxLeaseTTL()
if len(role.MaxTTL) == 0 {
maxTTL = maxSystemTTL
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid max ttl: %s", err))
}
}
if maxTTL > maxSystemTTL {
return nil, logical.ErrorResponse("Requested max TTL is higher than backend maximum")
if ttl != 0 && maxTTL != 0 && ttl > maxTTL {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" when both are specified`)
}

ttl := b.System().DefaultLeaseTTL()
if len(role.TTL) != 0 {
var err error
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, logical.ErrorResponse(fmt.Sprintf(
"Invalid ttl: %s", err))
}
}
if ttl > maxTTL {
// If they are using the system default, cap it to the role max;
// if it was specified on the command line, make it an error
if len(role.TTL) == 0 {
ttl = maxTTL
} else {
return nil, logical.ErrorResponse(
`"ttl" value must be less than "max_ttl" and/or backend default max lease TTL value`,
)
}
}

// Persist clamped TTLs
// Persist TTLs
role.TTL = ttl.String()
role.MaxTTL = maxTTL.String()
role.DefaultCriticalOptions = defaultCriticalOptions
Expand Down Expand Up @@ -551,13 +519,22 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
},
}, nil
} else if role.KeyType == KeyTypeCA {
ttl, err := parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return nil, err
}
maxTTL, err := parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return nil, err
}

return &logical.Response{
Data: map[string]interface{}{
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"max_ttl": role.MaxTTL,
"ttl": role.TTL,
"allowed_users": role.AllowedUsers,
"allowed_domains": role.AllowedDomains,
"default_user": role.DefaultUser,
"ttl": int64(ttl.Seconds()),
"max_ttl": int64(maxTTL.Seconds()),
"allowed_critical_options": role.AllowedCriticalOptions,
"allowed_extensions": role.AllowedExtensions,
"allow_user_certificates": role.AllowUserCertificates,
Expand Down
38 changes: 16 additions & 22 deletions builtin/logical/ssh/path_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func pathSign(b *backend) *framework.Path {
Description: `The desired role with configuration for this request.`,
},
"ttl": &framework.FieldSchema{
Type: framework.TypeString,
Type: framework.TypeDurationSecond,
Description: `The requested Time To Live for the SSH certificate;
sets the expiration date. If not specified
the role default, backend default, or system
Expand Down Expand Up @@ -345,40 +345,34 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole)
}

func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {

var ttl, maxTTL time.Duration
var ttlField string
ttlFieldInt, ok := data.GetOk("ttl")
if !ok {
ttlField = role.TTL
} else {
ttlField = ttlFieldInt.(string)
}
var err error

if len(ttlField) == 0 {
ttl = b.System().DefaultLeaseTTL()
ttlRaw, specifiedTTL := data.GetOk("ttl")
if specifiedTTL {
ttl = time.Duration(ttlRaw.(int)) * time.Second
} else {
var err error
ttl, err = parseutil.ParseDurationSecond(ttlField)
ttl, err = parseutil.ParseDurationSecond(role.TTL)
if err != nil {
return 0, fmt.Errorf("invalid requested ttl: %s", err)
return 0, err
}
}
if ttl == 0 {
ttl = b.System().DefaultLeaseTTL()
}

if len(role.MaxTTL) == 0 {
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, err
}
if maxTTL == 0 {
maxTTL = b.System().MaxLeaseTTL()
} else {
var err error
maxTTL, err = parseutil.ParseDurationSecond(role.MaxTTL)
if err != nil {
return 0, fmt.Errorf("invalid requested max ttl: %s", err)
}
}

if ttl > maxTTL {
// Don't error if they were using system defaults, only error if
// they specifically chose a bad TTL
if len(ttlField) == 0 {
if !specifiedTTL {
ttl = maxTTL
} else {
return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", maxTTL/time.Second)
Expand Down

0 comments on commit 3e81fe4

Please sign in to comment.