Skip to content

Commit

Permalink
Run end-to-end tests in parallel.
Browse files Browse the repository at this point in the history
The server accesses external dependencies via the context, instead of global variables, enabling the test parallelism.
  • Loading branch information
armsnyder committed Nov 7, 2020
1 parent 6f9fac7 commit 72dda22
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 75 deletions.
97 changes: 97 additions & 0 deletions pkg/server/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package server

import (
"context"
"encoding/json"
"fmt"
"log"
"os"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigatewaymanagementapi"
"github.com/aws/aws-sdk-go/service/dynamodb"
)

// This file contains methods for configuring and retrieving external dependencies off of a context.
// It serves as a bridge between the server and any external dependencies that may need to be
// swapped out during testing. The idea is that this custom context can be passed to the main
// Handler function in tests.

type handlerContextKey int

const (
tableNameKey handlerContextKey = iota
dynamoClientKey
sendMessageHandlerKey
)

type SendMessageHandler func(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, connectionID string, message interface{}) error

type HandlerContext struct {
context.Context
}

func NewHandlerContext(parent context.Context) *HandlerContext {
return &HandlerContext{Context: parent}
}

func (c *HandlerContext) WithTableName(v *string) *HandlerContext {
c.Context = context.WithValue(c.Context, tableNameKey, v)
return c
}

func (c *HandlerContext) WithDynamoClient(v *dynamodb.DynamoDB) *HandlerContext {
c.Context = context.WithValue(c.Context, dynamoClientKey, v)
return c
}

func (c *HandlerContext) WithSendMessageHandler(v SendMessageHandler) *HandlerContext {
c.Context = context.WithValue(c.Context, sendMessageHandlerKey, v)
return c
}

var defaultTableName = aws.String("othelgo")

func getTableName(ctx context.Context) *string {
if v, ok := ctx.Value(tableNameKey).(*string); ok {
return v
}
return defaultTableName
}

var defaultDynamoClient = dynamodb.New(session.Must(session.NewSession(aws.NewConfig().WithRegion(os.Getenv("AWS_REGION")))))

func getDynamoClient(ctx context.Context) *dynamodb.DynamoDB {
if v, ok := ctx.Value(dynamoClientKey).(*dynamodb.DynamoDB); ok {
return v
}
return defaultDynamoClient
}

func defaultSendMessageHandler(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, connectionID string, message interface{}) error {
data, err := json.Marshal(message)
if err != nil {
return err
}

endpoint := fmt.Sprintf("https://%s/%s/", reqCtx.DomainName, reqCtx.Stage)
client := apigatewaymanagementapi.New(session.Must(session.NewSession(aws.NewConfig().WithEndpoint(endpoint))))

log.Printf("Sending message to connection %s", connectionID)

_, err = client.PostToConnectionWithContext(ctx, &apigatewaymanagementapi.PostToConnectionInput{
ConnectionId: &connectionID,
Data: data,
})

return err
}

func getSendMessageHandler(ctx context.Context) SendMessageHandler {
if v, ok := ctx.Value(sendMessageHandlerKey).(SendMessageHandler); ok {
return v
}
return defaultSendMessageHandler
}
28 changes: 10 additions & 18 deletions pkg/server/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@ import (
"context"
"encoding/json"
"log"
"os"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"

"github.com/armsnyder/othelgo/pkg/common"
)

var (
tableName = aws.String("othelgo")
connectionsKey = makeKey("connections")
boardKeyValue = "board"
connectionsAttribute = "connections"
Expand All @@ -32,14 +29,9 @@ type gameItem struct {
PlayerRaw int `json:"player"`
}

// DynamoClient is the DynamoDB client.
// It is the only connection between this package and DynamoDB.
// It is exported so that it can be overridden in tests.
var DynamoClient = dynamodb.New(session.Must(session.NewSession(aws.NewConfig().WithRegion(os.Getenv("AWS_REGION")))))

func getAllConnectionIDs(ctx context.Context) ([]string, error) {
output, err := DynamoClient.GetItemWithContext(ctx, &dynamodb.GetItemInput{
TableName: tableName,
output, err := getDynamoClient(ctx).GetItemWithContext(ctx, &dynamodb.GetItemInput{
TableName: getTableName(ctx),
Key: connectionsKey,
})
if err != nil {
Expand All @@ -55,8 +47,8 @@ func getAllConnectionIDs(ctx context.Context) ([]string, error) {
}

func saveConnection(ctx context.Context, connectionID string) error {
_, err := DynamoClient.UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
TableName: tableName,
_, err := getDynamoClient(ctx).UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
TableName: getTableName(ctx),
Key: connectionsKey,
UpdateExpression: aws.String("ADD #c :v"),
ExpressionAttributeNames: map[string]*string{
Expand All @@ -71,8 +63,8 @@ func saveConnection(ctx context.Context, connectionID string) error {
}

func forgetConnection(ctx context.Context, connectionID string) error {
_, err := DynamoClient.UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
TableName: tableName,
_, err := getDynamoClient(ctx).UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
TableName: getTableName(ctx),
Key: connectionsKey,
UpdateExpression: aws.String("DELETE #c :v"),
ExpressionAttributeNames: map[string]*string{
Expand All @@ -91,8 +83,8 @@ func loadGame(ctx context.Context) (gameItem, error) {

log.Println("Loading game")

output, err := DynamoClient.GetItemWithContext(ctx, &dynamodb.GetItemInput{
TableName: tableName,
output, err := getDynamoClient(ctx).GetItemWithContext(ctx, &dynamodb.GetItemInput{
TableName: getTableName(ctx),
Key: makeKey(boardKeyValue),
})
if err != nil {
Expand Down Expand Up @@ -131,8 +123,8 @@ func saveGame(ctx context.Context, game gameItem) error {
return err
}

_, err = DynamoClient.PutItemWithContext(ctx, &dynamodb.PutItemInput{
TableName: tableName,
_, err = getDynamoClient(ctx).PutItemWithContext(ctx, &dynamodb.PutItemInput{
TableName: getTableName(ctx),
Item: item,
})

Expand Down
32 changes: 2 additions & 30 deletions pkg/server/management_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,11 @@ package server

import (
"context"
"encoding/json"
"fmt"
"log"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigatewaymanagementapi"
"golang.org/x/sync/errgroup"
)

// SendMessage sends a message to a connected client using the API Gateway Management API.
// It is the only connection between this package and the API Gateway Management API.
// It is exported so that the behavior can be overridden in tests.
var SendMessage = func(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, connectionID string, message interface{}) error {
data, err := json.Marshal(message)
if err != nil {
return err
}

endpoint := fmt.Sprintf("https://%s/%s/", reqCtx.DomainName, reqCtx.Stage)
client := apigatewaymanagementapi.New(session.Must(session.NewSession(aws.NewConfig().WithEndpoint(endpoint))))

log.Printf("Sending message to connection %s", connectionID)

_, err = client.PostToConnectionWithContext(ctx, &apigatewaymanagementapi.PostToConnectionInput{
ConnectionId: &connectionID,
Data: data,
})

return err
}

func broadcastMessage(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, message interface{}) error {
connectionIDs, err := getAllConnectionIDs(ctx)
if err != nil {
Expand All @@ -55,11 +27,11 @@ func broadcastMessage(ctx context.Context, reqCtx events.APIGatewayWebsocketProx
}

func reply(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, message interface{}) error {
return SendMessage(ctx, reqCtx, reqCtx.ConnectionID, message)
return getSendMessageHandler(ctx)(ctx, reqCtx, reqCtx.ConnectionID, message)
}

func sendMessage(ctx context.Context, reqCtx events.APIGatewayWebsocketProxyRequestContext, connectionID string, message interface{}) func() error {
return func() error {
return SendMessage(ctx, reqCtx, connectionID, message)
return getSendMessageHandler(ctx)(ctx, reqCtx, connectionID, message)
}
}
Loading

0 comments on commit 72dda22

Please sign in to comment.