diff --git a/Makefile b/Makefile index 4e228c19ca29..60fcc39c4d9a 100644 --- a/Makefile +++ b/Makefile @@ -84,6 +84,7 @@ proto: protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding protoc -I physical physical/types.proto --go_out=plugins=grpc:physical protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity + protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:. sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index a8a54a7cde01..d605153a4e5e 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "net/rpc" "strings" @@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { // This function creates a new db object from the stored configuration and // caches it in the connections map. The caller of this function needs to hold // the backend's write lock -func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { +func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) { db, ok := b.connections[name] if ok { return db, nil @@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. return nil, err } - err = db.Initialize(config.ConnectionDetails, true) + err = db.Initialize(ctx, config.ConnectionDetails, true) if err != nil { return nil, err } @@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) { func (b *databaseBackend) closeIfShutdown(name string, err error) { // Plugin has shutdown, close it so next call can reconnect. - if err == rpc.ErrShutdown { + switch err { + case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: b.Lock() b.clearConnection(name) b.Unlock() diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index f4dfd3b4862a..35d3639cd7d2 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) { RevocationStatements: defaultRevocationSQL, } - var actual dbplugin.Statements - if err := mapstructure.Decode(resp.Data, &actual); err != nil { - t.Fatal(err) + actual := dbplugin.Statements{ + CreationStatements: resp.Data["creation_statements"].(string), + RevocationStatements: resp.Data["revocation_statements"].(string), + RollbackStatements: resp.Data["rollback_statements"].(string), + RenewStatements: resp.Data["renew_statements"].(string), } if !reflect.DeepEqual(expected, actual) { diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 6df39489fea0..1d36386bc90d 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -1,10 +1,8 @@ package dbplugin import ( - "fmt" - "net/rpc" + "errors" "sync" - "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" @@ -17,11 +15,11 @@ type DatabasePluginClient struct { client *plugin.Client sync.Mutex - *databasePluginRPCClient + Database } func (dc *DatabasePluginClient) Close() error { - err := dc.databasePluginRPCClient.Close() + err := dc.Database.Close() dc.client.Kill() return err @@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. - databaseRPC := raw.(*databasePluginRPCClient) + var db Database + switch raw.(type) { + case *gRPCClient: + db = raw.(*gRPCClient) + case *databasePluginRPCClient: + logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name) + db = raw.(*databasePluginRPCClient) + default: + return nil, errors.New("unsupported client type") + } // Wrap RPC implimentation in DatabasePluginClient return &DatabasePluginClient{ - client: client, - databasePluginRPCClient: databaseRPC, + client: client, + Database: db, }, nil } - -// ---- RPC client domain ---- - -// databasePluginRPCClient implements Database and is used on the client to -// make RPC calls to a plugin. -type databasePluginRPCClient struct { - client *rpc.Client -} - -func (dr *databasePluginRPCClient) Type() (string, error) { - var dbType string - err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) - - return fmt.Sprintf("plugin-%s", dbType), err -} - -func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { - req := CreateUserRequest{ - Statements: statements, - UsernameConfig: usernameConfig, - Expiration: expiration, - } - - var resp CreateUserResponse - err = dr.client.Call("Plugin.CreateUser", req, &resp) - - return resp.Username, resp.Password, err -} - -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { - req := RenewUserRequest{ - Statements: statements, - Username: username, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { - req := RevokeUserRequest{ - Statements: statements, - Username: username, - } - - err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { - req := InitializeRequest{ - Config: conf, - VerifyConnection: verifyConnection, - } - - err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Close() error { - err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) - - return err -} diff --git a/builtin/logical/database/dbplugin/database.pb.go b/builtin/logical/database/dbplugin/database.pb.go new file mode 100644 index 000000000000..c4c4101968a6 --- /dev/null +++ b/builtin/logical/database/dbplugin/database.pb.go @@ -0,0 +1,556 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: builtin/logical/database/dbplugin/database.proto + +/* +Package dbplugin is a generated protocol buffer package. + +It is generated from these files: + builtin/logical/database/dbplugin/database.proto + +It has these top-level messages: + InitializeRequest + CreateUserRequest + RenewUserRequest + RevokeUserRequest + Statements + UsernameConfig + CreateUserResponse + TypeResponse + Empty +*/ +package dbplugin + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import google_protobuf "github.com/golang/protobuf/ptypes/timestamp" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type InitializeRequest struct { + Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"` +} + +func (m *InitializeRequest) Reset() { *m = InitializeRequest{} } +func (m *InitializeRequest) String() string { return proto.CompactTextString(m) } +func (*InitializeRequest) ProtoMessage() {} +func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *InitializeRequest) GetConfig() []byte { + if m != nil { + return m.Config + } + return nil +} + +func (m *InitializeRequest) GetVerifyConnection() bool { + if m != nil { + return m.VerifyConnection + } + return false +} + +type CreateUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` + Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` +} + +func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} } +func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) } +func (*CreateUserRequest) ProtoMessage() {} +func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *CreateUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig { + if m != nil { + return m.UsernameConfig + } + return nil +} + +func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp { + if m != nil { + return m.Expiration + } + return nil +} + +type RenewUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` + Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` +} + +func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} } +func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) } +func (*RenewUserRequest) ProtoMessage() {} +func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *RenewUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *RenewUserRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp { + if m != nil { + return m.Expiration + } + return nil +} + +type RevokeUserRequest struct { + Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` +} + +func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} } +func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) } +func (*RevokeUserRequest) ProtoMessage() {} +func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } + +func (m *RevokeUserRequest) GetStatements() *Statements { + if m != nil { + return m.Statements + } + return nil +} + +func (m *RevokeUserRequest) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +type Statements struct { + CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` + RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` + RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` + RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` +} + +func (m *Statements) Reset() { *m = Statements{} } +func (m *Statements) String() string { return proto.CompactTextString(m) } +func (*Statements) ProtoMessage() {} +func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } + +func (m *Statements) GetCreationStatements() string { + if m != nil { + return m.CreationStatements + } + return "" +} + +func (m *Statements) GetRevocationStatements() string { + if m != nil { + return m.RevocationStatements + } + return "" +} + +func (m *Statements) GetRollbackStatements() string { + if m != nil { + return m.RollbackStatements + } + return "" +} + +func (m *Statements) GetRenewStatements() string { + if m != nil { + return m.RenewStatements + } + return "" +} + +type UsernameConfig struct { + DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` + RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` +} + +func (m *UsernameConfig) Reset() { *m = UsernameConfig{} } +func (m *UsernameConfig) String() string { return proto.CompactTextString(m) } +func (*UsernameConfig) ProtoMessage() {} +func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +func (m *UsernameConfig) GetDisplayName() string { + if m != nil { + return m.DisplayName + } + return "" +} + +func (m *UsernameConfig) GetRoleName() string { + if m != nil { + return m.RoleName + } + return "" +} + +type CreateUserResponse struct { + Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` +} + +func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} } +func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) } +func (*CreateUserResponse) ProtoMessage() {} +func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + +func (m *CreateUserResponse) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *CreateUserResponse) GetPassword() string { + if m != nil { + return m.Password + } + return "" +} + +type TypeResponse struct { + Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"` +} + +func (m *TypeResponse) Reset() { *m = TypeResponse{} } +func (m *TypeResponse) String() string { return proto.CompactTextString(m) } +func (*TypeResponse) ProtoMessage() {} +func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } + +func (m *TypeResponse) GetType() string { + if m != nil { + return m.Type + } + return "" +} + +type Empty struct { +} + +func (m *Empty) Reset() { *m = Empty{} } +func (m *Empty) String() string { return proto.CompactTextString(m) } +func (*Empty) ProtoMessage() {} +func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } + +func init() { + proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") + proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") + proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") + proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") + proto.RegisterType((*Statements)(nil), "dbplugin.Statements") + proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") + proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") + proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") + proto.RegisterType((*Empty)(nil), "dbplugin.Empty") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Database service + +type DatabaseClient interface { + Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) + CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) + RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) + RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) + Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) + Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) +} + +type databaseClient struct { + cc *grpc.ClientConn +} + +func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient { + return &databaseClient{cc} +} + +func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) { + out := new(TypeResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { + out := new(CreateUserResponse) + err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Database service + +type DatabaseServer interface { + Type(context.Context, *Empty) (*TypeResponse, error) + CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) + RenewUser(context.Context, *RenewUserRequest) (*Empty, error) + RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) + Initialize(context.Context, *InitializeRequest) (*Empty, error) + Close(context.Context, *Empty) (*Empty, error) +} + +func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { + s.RegisterService(&_Database_serviceDesc, srv) +} + +func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Type(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Type", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Type(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).CreateUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/CreateUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RenewUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).RenewUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/RenewUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).RevokeUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/RevokeUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(InitializeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Initialize(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Initialize", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DatabaseServer).Close(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dbplugin.Database/Close", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DatabaseServer).Close(ctx, req.(*Empty)) + } + return interceptor(ctx, in, info, handler) +} + +var _Database_serviceDesc = grpc.ServiceDesc{ + ServiceName: "dbplugin.Database", + HandlerType: (*DatabaseServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Type", + Handler: _Database_Type_Handler, + }, + { + MethodName: "CreateUser", + Handler: _Database_CreateUser_Handler, + }, + { + MethodName: "RenewUser", + Handler: _Database_RenewUser_Handler, + }, + { + MethodName: "RevokeUser", + Handler: _Database_RevokeUser_Handler, + }, + { + MethodName: "Initialize", + Handler: _Database_Initialize_Handler, + }, + { + MethodName: "Close", + Handler: _Database_Close_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "builtin/logical/database/dbplugin/database.proto", +} + +func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 548 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, + 0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, + 0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, + 0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, + 0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, + 0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, + 0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, + 0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, + 0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, + 0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, + 0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, + 0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, + 0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, + 0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, + 0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, + 0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, + 0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, + 0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, + 0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, + 0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, + 0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, + 0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, + 0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, + 0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, + 0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, + 0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, + 0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, + 0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, + 0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, + 0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, + 0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, + 0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, + 0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, + 0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, + 0x94, 0x05, 0x00, 0x00, +} diff --git a/builtin/logical/database/dbplugin/database.proto b/builtin/logical/database/dbplugin/database.proto new file mode 100644 index 000000000000..d5e7d4068f7d --- /dev/null +++ b/builtin/logical/database/dbplugin/database.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; +package dbplugin; + +import "google/protobuf/timestamp.proto"; + +message InitializeRequest { + bytes config = 1; + bool verify_connection = 2; +} + +message CreateUserRequest { + Statements statements = 1; + UsernameConfig username_config = 2; + google.protobuf.Timestamp expiration = 3; +} + +message RenewUserRequest { + Statements statements = 1; + string username = 2; + google.protobuf.Timestamp expiration = 3; +} + +message RevokeUserRequest { + Statements statements = 1; + string username = 2; +} + +message Statements { + string creation_statements = 1; + string revocation_statements = 2; + string rollback_statements = 3; + string renew_statements = 4; +} + +message UsernameConfig { + string DisplayName = 1; + string RoleName = 2; +} + +message CreateUserResponse { + string username = 1; + string password = 2; +} + +message TypeResponse { + string type = 1; +} + +message Empty {} + +service Database { + rpc Type(Empty) returns (TypeResponse); + rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); + rpc RenewUser(RenewUserRequest) returns (Empty); + rpc RevokeUser(RevokeUserRequest) returns (Empty); + rpc Initialize(InitializeRequest) returns (Empty); + rpc Close(Empty) returns (Empty); +} diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 87dfa6c3143d..c8bbdf61d58c 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -1,6 +1,7 @@ package dbplugin import ( + "context" "time" metrics "github.com/armon/go-metrics" @@ -15,55 +16,56 @@ type databaseTracingMiddleware struct { next Database logger log.Logger - typeStr string + typeStr string + transport string } func (mw *databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } -func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) - return mw.next.CreateUser(statements, usernameConfig, expiration) + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) } -func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { +func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) - return mw.next.RenewUser(statements, username, expiration) + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport) + return mw.next.RenewUser(ctx, statements, username, expiration) } -func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { +func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) - return mw.next.RevokeUser(statements, username) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) - return mw.next.Initialize(conf, verifyConnection) + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) + return mw.next.Initialize(ctx, conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport) return mw.next.Close() } @@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } -func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) @@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) - return mw.next.CreateUser(statements, usernameConfig, expiration) + return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) } -func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { +func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) @@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s metrics.IncrCounter([]string{"database", "RenewUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) - return mw.next.RenewUser(statements, username, expiration) + return mw.next.RenewUser(ctx, statements, username, expiration) } -func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { +func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RevokeUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) @@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) - return mw.next.RevokeUser(statements, username) + return mw.next.RevokeUser(ctx, statements, username) } -func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { +func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) @@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) - return mw.next.Initialize(conf, verifyConnection) + return mw.next.Initialize(ctx, conf, verifyConnection) } func (mw *databaseMetricsMiddleware) Close() (err error) { diff --git a/builtin/logical/database/dbplugin/grpc_transport.go b/builtin/logical/database/dbplugin/grpc_transport.go new file mode 100644 index 000000000000..0b277968ced8 --- /dev/null +++ b/builtin/logical/database/dbplugin/grpc_transport.go @@ -0,0 +1,198 @@ +package dbplugin + +import ( + "context" + "encoding/json" + "errors" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + + "github.com/golang/protobuf/ptypes" +) + +var ( + ErrPluginShutdown = errors.New("plugin shutdown") +) + +// ---- gRPC Server domain ---- + +type gRPCServer struct { + impl Database +} + +func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) { + t, err := s.impl.Type() + if err != nil { + return nil, err + } + + return &TypeResponse{ + Type: t, + }, nil +} + +func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) { + e, err := ptypes.Timestamp(req.Expiration) + if err != nil { + return nil, err + } + + u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e) + + return &CreateUserResponse{ + Username: u, + Password: p, + }, err +} + +func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) { + e, err := ptypes.Timestamp(req.Expiration) + if err != nil { + return nil, err + } + err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e) + return &Empty{}, err +} + +func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) { + err := s.impl.RevokeUser(ctx, *req.Statements, req.Username) + return &Empty{}, err +} + +func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { + config := map[string]interface{}{} + + err := json.Unmarshal(req.Config, &config) + if err != nil { + return nil, err + } + + err = s.impl.Initialize(ctx, config, req.VerifyConnection) + return &Empty{}, err +} + +func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { + s.impl.Close() + return &Empty{}, nil +} + +// ---- gRPC client domain ---- + +type gRPCClient struct { + client DatabaseClient + clientConn *grpc.ClientConn +} + +func (c gRPCClient) Type() (string, error) { + // If the plugin has already shutdown, this will hang forever so we give it + // a one second timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return "", ErrPluginShutdown + } + resp, err := c.client.Type(ctx, &Empty{}) + if err != nil { + return "", err + } + + return resp.Type, err +} + +func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { + t, err := ptypes.TimestampProto(expiration) + if err != nil { + return "", "", err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return "", "", ErrPluginShutdown + } + + resp, err := c.client.CreateUser(ctx, &CreateUserRequest{ + Statements: &statements, + UsernameConfig: &usernameConfig, + Expiration: t, + }) + if err != nil { + return "", "", err + } + + return resp.Username, resp.Password, err +} + +func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error { + t, err := ptypes.TimestampProto(expiration) + if err != nil { + return err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + + _, err = c.client.RenewUser(ctx, &RenewUserRequest{ + Statements: &statements, + Username: username, + Expiration: t, + }) + + return err +} + +func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error { + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + _, err := c.client.RevokeUser(ctx, &RevokeUserRequest{ + Statements: &statements, + Username: username, + }) + + return err +} + +func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { + configRaw, err := json.Marshal(config) + if err != nil { + return err + } + + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + default: + return ErrPluginShutdown + } + + _, err = c.client.Initialize(ctx, &InitializeRequest{ + Config: configRaw, + VerifyConnection: verifyConnection, + }) + + return err +} + +func (c *gRPCClient) Close() error { + // If the plugin has already shutdown, this will hang forever so we give it + // a one second timeout. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + switch c.clientConn.GetState() { + case connectivity.Ready, connectivity.Idle: + _, err := c.client.Close(ctx, &Empty{}) + return err + } + + return nil +} diff --git a/builtin/logical/database/dbplugin/netrpc_transport.go b/builtin/logical/database/dbplugin/netrpc_transport.go new file mode 100644 index 000000000000..6f6f3a5bfeea --- /dev/null +++ b/builtin/logical/database/dbplugin/netrpc_transport.go @@ -0,0 +1,139 @@ +package dbplugin + +import ( + "context" + "fmt" + "net/rpc" + "time" +) + +// ---- RPC server domain ---- + +// databasePluginRPCServer implements an RPC version of Database and is run +// inside a plugin. It wraps an underlying implementation of Database. +type databasePluginRPCServer struct { + impl Database +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + var err error + *resp, err = ds.impl.Type() + return err +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error { + var err error + resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration) + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error { + err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration) + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error { + err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username) + return err +} + +func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { + err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) + return err +} + +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +// ---- RPC client domain ---- +// databasePluginRPCClient implements Database and is used on the client to +// make RPC calls to a plugin. +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() (string, error) { + var dbType string + err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType), err +} + +func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { + req := CreateUserRequestRPC{ + Statements: statements, + UsernameConfig: usernameConfig, + Expiration: expiration, + } + + var resp CreateUserResponse + err = dr.client.Call("Plugin.CreateUser", req, &resp) + + return resp.Username, resp.Password, err +} + +func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error { + req := RenewUserRequestRPC{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { + req := RevokeUserRequestRPC{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { + req := InitializeRequestRPC{ + Config: conf, + VerifyConnection: verifyConnection, + } + + err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +// ---- RPC Request Args Domain ---- + +type InitializeRequestRPC struct { + Config map[string]interface{} + VerifyConnection bool +} + +type CreateUserRequestRPC struct { + Statements Statements + UsernameConfig UsernameConfig + Expiration time.Time +} + +type RenewUserRequestRPC struct { + Statements Statements + Username string + Expiration time.Time +} + +type RevokeUserRequestRPC struct { + Statements Statements + Username string +} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 0becc9f4aa43..0f4bfee802fb 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -1,10 +1,13 @@ package dbplugin import ( + "context" "fmt" "net/rpc" "time" + "google.golang.org/grpc" + "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" log "github.com/mgutz/logxi/v1" @@ -13,29 +16,14 @@ import ( // Database is the interface that all database objects must implement. type Database interface { Type() (string, error) - CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) - RenewUser(statements Statements, username string, expiration time.Time) error - RevokeUser(statements Statements, username string) error + CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) + RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error + RevokeUser(ctx context.Context, statements Statements, username string) error - Initialize(config map[string]interface{}, verifyConnection bool) error + Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error Close() error } -// Statements set in role creation and passed into the database type's functions. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` -} - -// UsernameConfig is used to configure prefixes for the username to be -// generated. -type UsernameConfig struct { - DisplayName string - RoleName string -} - // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { @@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. return nil, err } + var transport string var db Database if pluginRunner.Builtin { // Plugin is builtin so we can retrieve an instance of the interface @@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. return nil, fmt.Errorf("unsuported database type: %s", pluginName) } + transport = "builtin" + } else { // create a DatabasePluginClient instance db, err = newPluginClient(sys, pluginRunner, logger) if err != nil { return nil, err } + + // Switch on the underlying database client type to get the transport + // method. + switch db.(*DatabasePluginClient).Database.(type) { + case *gRPCClient: + transport = "gRPC" + case *databasePluginRPCClient: + transport = "netRPC" + } + } typeStr, err := db.Type() @@ -82,9 +83,10 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. // Wrap with tracing middleware if logger.IsTrace() { db = &databaseTracingMiddleware{ - next: db, - typeStr: typeStr, - logger: logger, + transport: transport, + next: db, + typeStr: typeStr, + logger: logger, } } @@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e return &databasePluginRPCClient{client: c}, nil } -// ---- RPC Request Args Domain ---- - -type InitializeRequest struct { - Config map[string]interface{} - VerifyConnection bool -} - -type CreateUserRequest struct { - Statements Statements - UsernameConfig UsernameConfig - Expiration time.Time -} - -type RenewUserRequest struct { - Statements Statements - Username string - Expiration time.Time -} - -type RevokeUserRequest struct { - Statements Statements - Username string +func (d DatabasePlugin) GRPCServer(s *grpc.Server) error { + RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) + return nil } -// ---- RPC Response Args Domain ---- - -type CreateUserResponse struct { - Username string - Password string +func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) { + return &gRPCClient{ + client: NewDatabaseClient(c), + clientConn: c, + }, nil } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 3a785953dafa..96ef886b21d8 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,11 +1,13 @@ package dbplugin_test import ( + "context" "errors" "os" "testing" "time" + plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" @@ -20,7 +22,7 @@ type mockPlugin struct { } func (m *mockPlugin) Type() (string, error) { return "mock", nil } -func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { err = errors.New("err") if usernameConf.DisplayName == "" || expiration.IsZero() { return "", "", err @@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp return usernameConf.DisplayName, "test", nil } -func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { err := errors.New("err") if username == "" || expiration.IsZero() { return err @@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, return nil } -func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error { err := errors.New("err") if username == "" { return err @@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) delete(m.users, username) return nil } -func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { +func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { return err @@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { cores := cluster.Cores sys := vault.TestDynamicSystemView(cores[0].Core) - vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main") return cluster, sys } // This is not an actual test case, it's a helper function that will be executed // by the go-plugin client via an exec call. -func TestPlugin_Main(t *testing.T) { +func TestPlugin_GRPC_Main(t *testing.T) { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { return } @@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) { plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) } +// This is not an actual test case, it's a helper function that will be executed +// by the go-plugin client via an exec call. +func TestPlugin_NetRPC_Main(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + p := &mockPlugin{ + users: make(map[string][]string), + } + + args := []string{"--tls-skip-verify=true"} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig()) + serveConf := dbplugin.ServeConfig(p, tlsProvider) + serveConf.GRPCServer = nil + + plugin.Serve(serveConf) +} + func TestPlugin_Initialize(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() @@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(connectionDetails, true) + err = dbRaw.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) { RoleName: "test", } - us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) { // try and save the same user again to verify it saved the first time, this // should return an error - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err == nil { t.Fatal("expected an error, user wasn't created correctly") } @@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) { RoleName: "test", } - us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -216,7 +243,147 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try adding the same username back so we can verify it was removed + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +// Test the code is still compatible with an old netRPC plugin +func TestPlugin_NetRPC_Initialize(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = dbRaw.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_NetRPC_CreateUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + if us != "test" || pw != "test" { + t.Fatal("expected username and password to be 'test'") + } + + // try and save the same user again to verify it saved the first time, this + // should return an error + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("expected an error, user wasn't created correctly") + } +} + +func TestPlugin_NetRPC_RenewUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(context.Background(), connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + usernameConf := dbplugin.UsernameConfig{ + DisplayName: "test", + RoleName: "test", + } + + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_NetRPC_RevokeUser(t *testing.T) { + cluster, sys := getCluster(t) + defer cluster.Cleanup() + + db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + connectionDetails := map[string]interface{}{ + "test": 1, + } + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -226,19 +393,19 @@ func TestPlugin_RevokeUser(t *testing.T) { RoleName: "test", } - us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } // Test default revoke statememts - err = db.RevokeUser(dbplugin.Statements{}, us) + err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) if err != nil { t.Fatalf("err: %s", err) } // Try adding the same username back so we can verify it was removed - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 381f0ae2a1f4..0f44905aff29 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -10,6 +10,10 @@ import ( // Database implementation in a databasePluginRPCServer object and starts a // RPC server. func Serve(db Database, tlsProvider func() (*tls.Config, error)) { + plugin.Serve(ServeConfig(db, tlsProvider)) +} + +func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { dbPlugin := &DatabasePlugin{ impl: db, } @@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) { "database": dbPlugin, } - plugin.Serve(&plugin.ServeConfig{ + return &plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, TLSProvider: tlsProvider, - }) -} - -// ---- RPC server domain ---- - -// databasePluginRPCServer implements an RPC version of Database and is run -// inside a plugin. It wraps an underlying implementation of Database. -type databasePluginRPCServer struct { - impl Database -} - -func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - var err error - *resp, err = ds.impl.Type() - return err -} - -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { - var err error - resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { - err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { - err := ds.impl.RevokeUser(args.Statements, args.Username) - - return err -} - -func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { - err := ds.impl.Initialize(args.Config, args.VerifyConnection) - - return err -} - -func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { - ds.impl.Close() - return nil + GRPCServer: plugin.DefaultGRPCServer, + } } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index d1e6cb2923fc..95ec216d216f 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" @@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { b.clearConnection(name) // Execute plugin again, we don't need the object so throw away. - _, err := b.createDBObj(req.Storage, name) + _, err := b.createDBObj(context.TODO(), req.Storage, name) if err != nil { return nil, err } @@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } - err = db.Initialize(config.ConnectionDetails, verifyConnection) + err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection) if err != nil { db.Close() return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 610da726c606..8e1adce4a9eb 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) @@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { } // Create the user - username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration) + username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration) // Unlock unlockFunc() if err != nil { diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index c3dfcb973368..fb3e05bdf5b3 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "github.com/hashicorp/vault/logical" @@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) @@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - err := db.RenewUser(role.Statements, username, expireTime) + err := db.RenewUser(context.TODO(), role.Statements, username, expireTime) // Unlock unlockFunc() if err != nil { @@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { unlockFunc = b.Unlock // Create a new DB object - db, err = b.createDBObj(req.Storage, role.DBName) + db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) if err != nil { unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } - err = db.RevokeUser(role.Statements, username) + err = db.RevokeUser(context.TODO(), role.Statements, username) // Unlock unlockFunc() if err != nil { diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 2047651ed2af..bd9986e2db87 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin SecureConfig: secureConfig, TLSConfig: clientTLSConfig, Logger: namedLogger, + AllowedProtocols: []plugin.Protocol{ + plugin.ProtocolNetRPC, + plugin.ProtocolGRPC, + }, } client := plugin.NewClient(clientConfig) diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index d94beedd3585..040b52af3f81 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) { func TestBackendHandleRequest_renewAuth(t *testing.T) { b := &Backend{} - resp, err := b.HandleRequest(logical.RenewAuthRequest( - "/foo", &logical.Auth{}, nil)) + resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) { AuthRenew: callback, } - _, err := b.HandleRequest(logical.RenewAuthRequest( - "/foo", &logical.Auth{}, nil)) + _, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) { Secrets: []*Secret{secret}, } - _, err := b.HandleRequest(logical.RenewRequest( - "/foo", secret.Response(nil, nil).Secret, nil)) + _, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil)) if err != nil { t.Fatalf("err: %s", err) } @@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) { Secrets: []*Secret{secret}, } - _, err := b.HandleRequest(logical.RevokeRequest( - "/foo", secret.Response(nil, nil).Secret, nil)) + _, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil)) if err != nil { t.Fatalf("err: %s", err) } diff --git a/logical/request.go b/logical/request.go index edde0417fd6b..5e5102d1c57f 100644 --- a/logical/request.go +++ b/logical/request.go @@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) { } // RenewRequest creates the structure of the renew request. -func RenewRequest( - path string, secret *Secret, data map[string]interface{}) *Request { +func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request { return &Request{ Operation: RenewOperation, Path: path, @@ -211,8 +210,7 @@ func RenewRequest( } // RenewAuthRequest creates the structure of the renew request for an auth. -func RenewAuthRequest( - path string, auth *Auth, data map[string]interface{}) *Request { +func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request { return &Request{ Operation: RenewOperation, Path: path, @@ -222,8 +220,7 @@ func RenewAuthRequest( } // RevokeRequest creates the structure of the revoke request. -func RevokeRequest( - path string, secret *Secret, data map[string]interface{}) *Request { +func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request { return &Request{ Operation: RevokeOperation, Path: path, diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index c0b5fd5d4254..221784e0fc31 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "strings" "time" @@ -21,6 +22,8 @@ const ( cassandraTypeName = "cassandra" ) +var _ dbplugin.Database = &Cassandra{} + // Cassandra is an implementation of Database interface type Cassandra struct { connutil.ConnectionProducer @@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) { return cassandraTypeName, nil } -func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.Connection() +func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) { + session, err := c.Connection(ctx) if err != nil { return nil, err } @@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { // CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by // the CreationStatement provided. -func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() defer c.Unlock() // Get the connection - session, err := c.getConnection() + session, err := c.getConnection(ctx) if err != nil { return "", "", err } @@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db } // RenewUser is not supported on Cassandra, so this is a no-op. -func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } // RevokeUser attempts to drop the specified user. -func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { +func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the lock c.Lock() defer c.Unlock() - session, err := c.getConnection() + session, err := c.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index 0f4d3306e3ac..c31139de7501 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "os" "strconv" "testing" @@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) { db := dbRaw.(*Cassandra) connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": "4", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*Cassandra) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/cassandra/connection_producer.go b/plugins/database/cassandra/connection_producer.go index 45d46518b6d0..2f7e06f86dab 100644 --- a/plugins/database/cassandra/connection_producer.go +++ b/plugins/database/cassandra/connection_producer.go @@ -1,6 +1,7 @@ package cassandra import ( + "context" "crypto/tls" "fmt" "strings" @@ -43,7 +44,7 @@ type cassandraConnectionProducer struct { sync.Mutex } -func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } } @@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve return nil } -func (c *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { if !c.Initialized { return nil, connutil.ErrNotInitialized } diff --git a/plugins/database/hana/hana.go b/plugins/database/hana/hana.go index aa2b53d65084..c4aaf972ff58 100644 --- a/plugins/database/hana/hana.go +++ b/plugins/database/hana/hana.go @@ -1,6 +1,7 @@ package hana import ( + "context" "database/sql" "fmt" "strings" @@ -26,6 +27,8 @@ type HANA struct { credsutil.CredentialsProducer } +var _ dbplugin.Database = &HANA{} + // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} @@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) { return hanaTypeName, nil } -func (h *HANA) getConnection() (*sql.DB, error) { - db, err := h.Connection() +func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := h.Connection(ctx) if err != nil { return nil, err } @@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) { // CreateUser generates the username/password on the underlying HANA secret backend // as instructed by the CreationStatement provided. -func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock h.Lock() defer h.Unlock() // Get the connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return "", "", err } @@ -153,9 +156,9 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi } // Renewing hana user just means altering user's valid until property -func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } @@ -193,14 +196,14 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira } // Revoking hana user will deactivate user and try to perform a soft drop -func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error { +func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // default revoke will be a soft drop on user if statements.RevocationStatements == "" { - return h.revokeUserDefault(username) + return h.revokeUserDefault(ctx, username) } // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } @@ -239,9 +242,9 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error return nil } -func (h *HANA) revokeUserDefault(username string) error { +func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { // Get connection - db, err := h.getConnection() + db, err := h.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/hana/hana_test.go b/plugins/database/hana/hana_test.go index 7cff7f1f3ab3..8845fa3b8e69 100644 --- a/plugins/database/hana/hana_test.go +++ b/plugins/database/hana/hana_test.go @@ -1,6 +1,7 @@ package hana import ( + "context" "database/sql" "fmt" "os" @@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) { CreationStatements: testHANARole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*HANA) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) { } // Test default revoke statememts - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) { } // Test custom revoke statememt - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) if err != nil { t.Fatalf("err: %s", err) } @@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) { } statements.RevocationStatements = testHANADrop - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mongodb/connection_producer.go b/plugins/database/mongodb/connection_producer.go index a9ff64a943fa..9674d10bf7c1 100644 --- a/plugins/database/mongodb/connection_producer.go +++ b/plugins/database/mongodb/connection_producer.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" @@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct { } // Initialize parses connection configuration. -func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } @@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri } // Connection creates a database connection. -func (c *mongoDBConnectionProducer) Connection() (interface{}, error) { +func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) { if !c.Initialized { return nil, connutil.ErrNotInitialized } diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index 52671dae2f5c..8b2ee802b0f2 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "io" "strings" "time" @@ -27,6 +28,8 @@ type MongoDB struct { credsutil.CredentialsProducer } +var _ dbplugin.Database = &MongoDB{} + // New returns a new MongoDB instance func New() (interface{}, error) { connProducer := &mongoDBConnectionProducer{} @@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) { return mongoDBTypeName, nil } -func (m *MongoDB) getConnection() (*mgo.Session, error) { - session, err := m.Connection() +func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) { + session, err := m.Connection(ctx) if err != nil { return nil, err } @@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) { // // JSON Example: // { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] } -func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() @@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl return "", "", dbutil.ErrEmptyCreationStatement } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl if err := m.ConnectionProducer.Close(); err != nil { return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl } // RenewUser is not supported on MongoDB, so this is a no-op. -func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } // RevokeUser drops the specified user from the authentication databse. If none is provided // in the revocation statement, the default "admin" authentication database will be assumed. -func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error { - session, err := m.getConnection() +func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { + session, err := m.getConnection(ctx) if err != nil { return err } @@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er if err := m.ConnectionProducer.Close(); err != nil { return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) } - session, err := m.getConnection() + session, err := m.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index 4c1eacb66cb5..cd948af81fbd 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -1,6 +1,7 @@ package mongodb import ( + "context" "fmt" "os" "testing" @@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) { db := dbRaw.(*MongoDB) connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { t.Fatalf("err: %s", err) } db := dbRaw.(*MongoDB) - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { } // Test default revocation statememt - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 7b920c8c991f..d43146c00739 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "strings" @@ -18,6 +19,8 @@ import ( const msSQLTypeName = "mssql" +var _ dbplugin.Database = &MSSQL{} + // MSSQL is an implementation of Database interface type MSSQL struct { connutil.ConnectionProducer @@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) { return msSQLTypeName, nil } -func (m *MSSQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() +func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := m.Connection(ctx) if err != nil { return nil, err } @@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) { // CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by // the CreationStatement provided. -func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } @@ -146,13 +149,13 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir // RevokeUser attempts to drop the specified user. It will first attempt to disable login, // then kill pending connections from that user, and finally drop the user and login from the // database instance. -func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { if statements.RevocationStatements == "" { - return m.revokeUserDefault(username) + return m.revokeUserDefault(ctx, username) } // Get connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } @@ -191,9 +194,9 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro return nil } -func (m *MSSQL) revokeUserDefault(username string) error { +func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { // Get connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 5a00890bf2b1..7d2571c3d990 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "os" @@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) { CreationStatements: testMSSQLRole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*MSSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { t.Fatal("Credentials were not revoked") } - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { // Test custom revoke statememt statements.RevocationStatements = testMSSQLDrop - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 87289274e0ee..38c928c35a59 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "strings" "time" @@ -30,6 +31,8 @@ var ( LegacyUsernameLen int = 16 ) +var _ dbplugin.Database = &MySQL{} + type MySQL struct { connutil.ConnectionProducer credsutil.CredentialsProducer @@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) { return mySQLTypeName, nil } -func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() +func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := m.Connection(ctx) if err != nil { return nil, err } @@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { // Grab the lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return "", "", err } @@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return "", "", err } @@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug "expiration": expirationStr, }) - stmt, err := tx.Prepare(query) + stmt, err := tx.PrepareContext(ctx, query) if err != nil { // If the error code we get back is Error 1295: This command is not // supported in the prepared statement protocol yet, we will execute @@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug // prepare supported commands. If there is no error when running we // will continue to the next statement. if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { - _, err = tx.Exec(query) + _, err = tx.ExecContext(ctx, query) if err != nil { return "", "", err } @@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug return "", "", err } defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + if _, err := stmt.ExecContext(ctx); err != nil { return "", "", err } } @@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug } // NOOP -func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { return nil } -func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the read lock m.Lock() defer m.Unlock() // Get the connection - db, err := m.getConnection() + db, err := m.getConnection(ctx) if err != nil { return err } @@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro } // Start a transaction - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro // 1295: This command is not supported in the prepared statement protocol yet // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.Exec(query) + _, err = tx.ExecContext(ctx, query) if err != nil { return err } diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 203158b461a1..fbbc87008095 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "fmt" "os" @@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) { db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) { CreationStatements: testMySQLRoleWildCard, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) { } // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) { // Test with a manualy prepare statement statements.CreationStatements = testMySQLRolePreparedStmt - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { CreationStatements: testMySQLRoleWildCard, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { } // Test a second time to make sure usernames don't collide - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) { dbRaw, _ := f() db := dbRaw.(*MySQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) { } statements.CreationStatements = testMySQLRoleWildCard - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) { // Test custom revoke statements statements.RevocationStatements = testMySQLRevocationSQL - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 93fa8a85425d..0a7c98e918fd 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -1,6 +1,7 @@ package postgresql import ( + "context" "database/sql" "fmt" "strings" @@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ` ) +var _ dbplugin.Database = &PostgreSQL{} + // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} @@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) { return postgreSQLTypeName, nil } -func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() +func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { + db, err := p.Connection(ctx) if err != nil { return nil, err } @@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { +func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { if statements.CreationStatements == "" { return "", "", dbutil.ErrEmptyCreationStatement } @@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d } // Get the connection - db, err := p.getConnection() + db, err := p.getConnection(ctx) if err != nil { return "", "", err @@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d return username, password, nil } -func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { +func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { p.Lock() defer p.Unlock() @@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, renewStmts = defaultPostgresRenewSQL } - db, err := p.getConnection() + db, err := p.getConnection(ctx) if err != nil { return err } @@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, return nil } -func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { +func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { // Grab the lock p.Lock() defer p.Unlock() if statements.RevocationStatements == "" { - return p.defaultRevokeUser(username) + return p.defaultRevokeUser(ctx, username) } - return p.customRevokeUser(username, statements.RevocationStatements) + return p.customRevokeUser(ctx, username, statements.RevocationStatements) } -func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { - db, err := p.getConnection() +func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { + db, err := p.getConnection(ctx) if err != nil { return err } @@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { return nil } -func (p *PostgreSQL) defaultRevokeUser(username string) error { - db, err := p.getConnection() +func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { + db, err := p.getConnection(ctx) if err != nil { return err } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 1e53d9118ab4..8f4ebb67a67c 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -1,6 +1,7 @@ package postgresql import ( + "context" "database/sql" "fmt" "os" @@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { "max_open_connections": "5", } - err = db.Initialize(connectionDetails, true) + err = db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } // Test with no configured Creation Statememt - _, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) if err == nil { t.Fatal("Expected error when no creation statement is provided") } @@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { CreationStatements: testPostgresRole, } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } statements.CreationStatements = testPostgresReadOnlyRole - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } statements.RenewStatements = defaultPostgresRenewSQL - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { t.Fatalf("Could not connect with new credentials: %s", err) } - err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { dbRaw, _ := New() db := dbRaw.(*PostgreSQL) - err := db.Initialize(connectionDetails, true) + err := db.Initialize(context.Background(), connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { RoleName: "test", } - username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } // Test default revoke statememts - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } @@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { t.Fatal("Credentials were not revoked") } - username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) + username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) if err != nil { t.Fatalf("err: %s", err) } @@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { // Test custom revoke statements statements.RevocationStatements = defaultPostgresRevocationSQL - err = db.RevokeUser(statements, username) + err = db.RevokeUser(context.Background(), statements, username) if err != nil { t.Fatalf("err: %s", err) } diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go index d36d5719d6a8..7cf23c5c3e73 100644 --- a/plugins/helper/database/connutil/connutil.go +++ b/plugins/helper/database/connutil/connutil.go @@ -1,6 +1,7 @@ package connutil import ( + "context" "errors" "sync" ) @@ -14,8 +15,8 @@ var ( // connections and is used in all the builtin database types. type ConnectionProducer interface { Close() error - Initialize(map[string]interface{}, bool) error - Connection() (interface{}, error) + Initialize(context.Context, map[string]interface{}, bool) error + Connection(context.Context) (interface{}, error) sync.Locker } diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index c325cbc187e0..2e34065d0341 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -1,6 +1,7 @@ package connutil import ( + "context" "database/sql" "fmt" "strings" @@ -25,7 +26,7 @@ type SQLConnectionProducer struct { sync.Mutex } -func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { +func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { c.Lock() defer c.Unlock() @@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo c.Initialized = true if verifyConnection { - if _, err := c.Connection(); err != nil { + if _, err := c.Connection(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } - if err := c.db.Ping(); err != nil { + if err := c.db.PingContext(ctx); err != nil { return fmt.Errorf("error verifying connection: %s", err) } } @@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo return nil } -func (c *SQLConnectionProducer) Connection() (interface{}, error) { +func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { if !c.Initialized { return nil, ErrNotInitialized } // If we already have a DB, test it and return if c.db != nil { - if err := c.db.Ping(); err == nil { + if err := c.db.PingContext(ctx); err == nil { return c.db, nil } // If the ping was unsuccessful, close it and ignore errors as we'll be diff --git a/vault/expiration.go b/vault/expiration.go index 23176ae4453f..6ebbb99b8d48 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -1017,8 +1017,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error { } // Handle standard revocation via backends - resp, err := m.router.Route(logical.RevokeRequest( - le.Path, le.Secret, le.Data)) + resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data)) if err != nil || (resp != nil && resp.IsError()) { return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err) }