Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add x509 Client Auth to MongoDB Database Plugin #8329

Merged
merged 14 commits into from
Feb 13, 2020
4 changes: 2 additions & 2 deletions builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ type databaseBackend struct {
roleLocks []*locksutil.LockEntry
}

func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*dbplugin.DatabaseConfig, error) {
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, errwrap.Wrapf("failed to read connection configuration: {{err}}", err)
Expand All @@ -137,7 +137,7 @@ func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage,
return nil, fmt.Errorf("failed to find entry for connection with name: %q", name)
}

var config DatabaseConfig
var config dbplugin.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
Expand Down
30 changes: 16 additions & 14 deletions builtin/logical/database/path_config_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ var (
respErrEmptyName = "empty name attribute given"
)

// DatabaseConfig is used by the Factory function to configure a Database
// object.
type DatabaseConfig struct {
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
// ConnectionDetails stores the database specific connection settings needed
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`

RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
}

// pathResetConnection configures a path to reset a plugin.
func pathResetConnection(b *databaseBackend) *framework.Path {
return &framework.Path{
Expand Down Expand Up @@ -185,7 +173,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return nil, nil
}

var config DatabaseConfig
var config dbplugin.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
Expand All @@ -202,12 +190,26 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {

delete(config.ConnectionDetails, "password")

config = b.redact(name, config)

return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
}

func (b *databaseBackend) redact(name string, config dbplugin.DatabaseConfig) dbplugin.DatabaseConfig {
conn, exists := b.connections[name]
if !exists {
return config
}

if r, ok := conn.Database.(dbplugin.Redaction); ok {
return r.Redact(config)
}
return config
}

// connectionDeleteHandler deletes the connection configuration
func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
Expand Down Expand Up @@ -241,7 +243,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}

// Baseline
config := &DatabaseConfig{}
config := &dbplugin.DatabaseConfig{}

entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions helper/builtinplugins/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,18 @@ func newRegistry() *registry {
"alicloud": logicalAlicloud.Factory,
"aws": logicalAws.Factory,
"azure": logicalAzure.Factory,
"cassandra": logicalCass.Factory,
"cassandra": logicalCass.Factory, // Deprecated
"consul": logicalConsul.Factory,
"gcp": logicalGcp.Factory,
"gcpkms": logicalGcpKms.Factory,
"kv": logicalKv.Factory,
"mongodb": logicalMongo.Factory,
"mongodb": logicalMongo.Factory, // Deprecated
"mongodbatlas": logicalMongoAtlas.Factory,
"mssql": logicalMssql.Factory,
"mysql": logicalMysql.Factory,
"mssql": logicalMssql.Factory, // Deprecated
"mysql": logicalMysql.Factory, // Deprecated
"nomad": logicalNomad.Factory,
"pki": logicalPki.Factory,
"postgresql": logicalPostgres.Factory,
"postgresql": logicalPostgres.Factory, // Deprecated
"rabbitmq": logicalRabbit.Factory,
"ssh": logicalSsh.Factory,
"totp": logicalTotp.Factory,
Expand Down
129 changes: 129 additions & 0 deletions plugins/database/mongodb/certs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package mongodb

import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"testing"
"time"
)

type Cert struct {
Key *rsa.PrivateKey
KeyPem []byte

Cert *x509.Certificate
CertPem []byte

TLSCert tls.Certificate
}

func MakeCert(t *testing.T, parent *x509.Certificate) (bundle Cert) {
pcman312 marked this conversation as resolved.
Show resolved Hide resolved
key, keyPem := makeKey(t)

now := time.Now()

template := &x509.Certificate{
IsCA: false,
SerialNumber: makeSerial(t),
Subject: pkix.Name{
CommonName: "unittest",
},

NotBefore: now,
NotAfter: now.Add(24 * time.Hour),

KeyUsage: x509.KeyUsageDigitalSignature |
x509.KeyUsageKeyEncipherment |
x509.KeyUsageKeyAgreement,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
SubjectKeyId: getSubjKeyID(t, key),
DNSNames: []string{"localhost"},
}

if parent == nil {
parent = template
}

cert, err := x509.CreateCertificate(rand.Reader, template, parent, key.Public(), key)
if err != nil {
t.Fatalf("Unable to generate cert: %s", err)
}
x509Cert, err := x509.ParseCertificate(cert)
if err != nil {
t.Fatalf("Unable to generate cert: %s", err)
}

certPem := pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: cert,
},
)

tlsCert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
t.Fatalf("Unable to parse X509 key pair: %s", err)
}

bundle = Cert{
Key: key,
KeyPem: keyPem,
Cert: x509Cert,
CertPem: certPem,
TLSCert: tlsCert,
}
return bundle
}

func makeKey(t *testing.T) (key *rsa.PrivateKey, pemBytes []byte) {
t.Helper()

privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Unable to generate key for cert: %s", err)
}

privKeyPem := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
},
)

return privKey, privKeyPem
}

func makeSerial(t *testing.T) *big.Int {
v := &big.Int{}
serialNumberLimit := v.Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
t.Fatalf("Unable to generate serial number: %s", err)
}
return serialNumber
}

// Pulled from sdk/helper/certutil & slightly modified for test usage
func getSubjKeyID(t *testing.T, privateKey crypto.Signer) []byte {
t.Helper()

if privateKey == nil {
t.Fatalf("passed-in private key is nil")
}

marshaledKey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
if err != nil {
t.Fatalf("error marshalling public key: %s", err)
}

subjKeyID := sha1.Sum(marshaledKey)

return subjKeyID[:]
}
Loading