Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport [1.7.x]: Cassandra: Refactor PEM parsing logic (#11861) #11921

Merged
merged 1 commit into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/11861.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
secrets/database/cassandra: Fixed issue where the PEM parsing logic of `pem_bundle` and `pem_json` didn't work for CA-only configurations
```
53 changes: 42 additions & 11 deletions helper/testhelpers/cassandra/cassandrahelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,34 @@ import (
)

type containerConfig struct {
version string
copyFromTo map[string]string
sslOpts *gocql.SslOptions
containerName string
imageName string
version string
copyFromTo map[string]string
env []string

sslOpts *gocql.SslOptions
}

type ContainerOpt func(*containerConfig)

func ContainerName(name string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.containerName = name
}
}

func Image(imageName string, version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.imageName = imageName
cfg.version = version

// Reset the environment because there's a very good chance the default environment doesn't apply to the
// non-default image being used
cfg.env = nil
}
}

func Version(version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.version = version
Expand All @@ -33,6 +54,12 @@ func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
}
}

func Env(keyValue string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.env = append(cfg.env, keyValue)
}
}

func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
return func(cfg *containerConfig) {
cfg.sslOpts = sslOpts
Expand Down Expand Up @@ -63,7 +90,9 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
}

containerCfg := &containerConfig{
version: "3.11",
imageName: "cassandra",
version: "3.11",
env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
}

for _, opt := range opts {
Expand All @@ -79,13 +108,15 @@ func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
copyFromTo[absFrom] = to
}

runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "cassandra",
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
})
runOpts := docker.RunOptions{
ContainerName: containerCfg.containerName,
ImageRepo: containerCfg.imageName,
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: containerCfg.env,
}
runner, err := docker.NewServiceRunner(runOpts)
if err != nil {
t.Fatalf("Could not start docker cassandra: %s", err)
}
Expand Down
50 changes: 31 additions & 19 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,27 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
}

func TestInitialize(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()
t.Run("integer protocol version", func(t *testing.T) {
// getCassandra performs an Initialize call
db, cleanup := getCassandra(t, 4)
t.Cleanup(cleanup)

err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
})

db, cleanup = getCassandra(t, "4")
defer cleanup()
t.Run("string protocol version", func(t *testing.T) {
// getCassandra performs an Initialize call
db, cleanup := getCassandra(t, "4")
t.Cleanup(cleanup)

err := db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

func TestCreateUser(t *testing.T) {
Expand All @@ -74,7 +85,7 @@ func TestCreateUser(t *testing.T) {
newUserReq dbplugin.NewUserRequest
expectErr bool
expectedUsernameRegex string
assertCreds func(t testing.TB, address string, port int, username, password string, timeout time.Duration)
assertCreds func(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration)
}

tests := map[string]testCase{
Expand Down Expand Up @@ -160,7 +171,7 @@ func TestCreateUser(t *testing.T) {
t.Fatalf("no error expected, got: %s", err)
}
require.Regexp(t, test.expectedUsernameRegex, newUserResp.Username)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, 5*time.Second)
test.assertCreds(t, db.Hosts, db.Port, newUserResp.Username, test.newUserReq.Password, nil, 5*time.Second)
})
}
}
Expand All @@ -184,7 +195,7 @@ func TestUpdateUserPassword(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

newPassword := "somenewpassword"
updateReq := dbplugin.UpdateUserRequest{
Expand All @@ -198,7 +209,7 @@ func TestUpdateUserPassword(t *testing.T) {

dbtesting.AssertUpdateUser(t, db, updateReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, nil, 5*time.Second)
}

func TestDeleteUser(t *testing.T) {
Expand All @@ -220,21 +231,21 @@ func TestDeleteUser(t *testing.T) {

createResp := dbtesting.AssertNewUser(t, db, createReq)

assertCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)

deleteReq := dbplugin.DeleteUserRequest{
Username: createResp.Username,
}

dbtesting.AssertDeleteUser(t, db, deleteReq)

assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, 5*time.Second)
assertNoCreds(t, db.Hosts, db.Port, createResp.Username, password, nil, 5*time.Second)
}

func assertCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()
op := func() error {
return connect(t, address, port, username, password)
return connect(t, address, port, username, password, sslOpts)
}
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = timeout
Expand All @@ -248,7 +259,7 @@ func assertCreds(t testing.TB, address string, port int, username, password stri
}
}

func connect(t testing.TB, address string, port int, username, password string) error {
func connect(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions) error {
t.Helper()
clusterConfig := gocql.NewCluster(address)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Expand All @@ -257,6 +268,7 @@ func connect(t testing.TB, address string, port int, username, password string)
}
clusterConfig.ProtoVersion = 4
clusterConfig.Port = port
clusterConfig.SslOpts = sslOpts

session, err := clusterConfig.CreateSession()
if err != nil {
Expand All @@ -266,12 +278,12 @@ func connect(t testing.TB, address string, port int, username, password string)
return nil
}

func assertNoCreds(t testing.TB, address string, port int, username, password string, timeout time.Duration) {
func assertNoCreds(t testing.TB, address string, port int, username, password string, sslOpts *gocql.SslOptions, timeout time.Duration) {
t.Helper()

op := func() error {
// "Invert" the error so the backoff logic sees a failure to connect as a success
err := connect(t, address, port, username, password)
err := connect(t, address, port, username, password, sslOpts)
if err != nil {
return nil
}
Expand Down
100 changes: 27 additions & 73 deletions plugins/database/cassandra/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -40,7 +39,7 @@ type cassandraConnectionProducer struct {

connectTimeout time.Duration
socketKeepAlive time.Duration
certBundle *certutil.CertBundle
sslOpts *gocql.SslOptions
rawConfig map[string]interface{}

Initialized bool
Expand Down Expand Up @@ -83,38 +82,46 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
return fmt.Errorf("username cannot be empty")
case len(c.Password) == 0:
return fmt.Errorf("password cannot be empty")
case len(c.PemJSON) > 0 && len(c.PemBundle) > 0:
return fmt.Errorf("cannot specify both pem_json and pem_bundle")
}

var tlsMinVersion uint16 = tls.VersionTLS12
if c.TLSMinVersion != "" {
ver, exists := tlsutil.TLSLookup[c.TLSMinVersion]
if !exists {
return fmt.Errorf("unrecognized TLS version [%s]", c.TLSMinVersion)
}
tlsMinVersion = ver
}

var certBundle *certutil.CertBundle
var parsedCertBundle *certutil.ParsedCertBundle
switch {
case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
cfg, err := jsonBundleToTLSConfig(c.PemJSON, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %w", err)
return fmt.Errorf("failed to parse pem_json: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true

case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
cfg, err := pemBundleToTLSConfig(c.PemBundle, tlsMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return fmt.Errorf("error parsing the given PEM information: %w", err)
return fmt.Errorf("failed to parse pem_bundle: %w", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
c.sslOpts = &gocql.SslOptions{
Config: cfg,
EnableHostVerification: !cfg.InsecureSkipVerify,
}
c.certBundle = certBundle
c.TLS = true
}

if c.InsecureTLS {
c.TLS = true
case c.InsecureTLS:
c.sslOpts = &gocql.SslOptions{
EnableHostVerification: !c.InsecureTLS,
}
}

// Set initialized to true at this point since all fields are set,
Expand Down Expand Up @@ -183,14 +190,7 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

clusterConfig.Timeout = c.connectTimeout
clusterConfig.SocketKeepalive = c.socketKeepAlive

if c.TLS {
sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.TLSServerName, c.InsecureTLS)
if err != nil {
return nil, err
}
clusterConfig.SslOpts = sslOpts
}
clusterConfig.SslOpts = c.sslOpts

if c.LocalDatacenter != "" {
clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
Expand Down Expand Up @@ -231,52 +231,6 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
return session, nil
}

func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion, serverName string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
tlsConfig := &tls.Config{}
if certBundle != nil {
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
}
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}

parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}

tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
}

tlsConfig.InsecureSkipVerify = insecureSkipVerify

if serverName != "" {
tlsConfig.ServerName = serverName
}

if minTLSVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}

opts := &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !insecureSkipVerify,
}
return opts, nil
}

func (c *cassandraConnectionProducer) secretValues() map[string]string {
return map[string]string{
c.Password: "[password]",
Expand Down
Loading