Skip to content

Commit

Permalink
Fix SRA auth trailing checksum retry bug (#2438)
Browse files Browse the repository at this point in the history
  • Loading branch information
isaiahvita authored Jan 3, 2024
1 parent efbc5aa commit 0f8ad11
Show file tree
Hide file tree
Showing 30 changed files with 325 additions and 16 deletions.
6 changes: 4 additions & 2 deletions aws/retry/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,12 @@ func AddRetryMiddlewares(stack *smithymiddle.Stack, options AddRetryMiddlewaresO
middleware.LogAttempts = options.LogRetryAttempts
})

if err := stack.Finalize.Add(attempt, smithymiddle.After); err != nil {
// index retry to before signing, if signing exists
if err := stack.Finalize.Insert(attempt, "Signing", smithymiddle.Before); err != nil {
return err
}
if err := stack.Finalize.Add(&MetricsHeader{}, smithymiddle.After); err != nil {

if err := stack.Finalize.Insert(&MetricsHeader{}, attempt.ID(), smithymiddle.After); err != nil {
return err
}
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,14 @@ private void writeConvertToPresignMiddleware(
if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok {
stack.Finalize.Remove(($1P)(nil).ID())
}""", SdkGoTypes.ServiceInternal.AcceptEncoding.DisableGzip);
writer.write("""
if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok {
stack.Finalize.Remove(($1P)(nil).ID())
}""", SdkGoTypes.Aws.Retry.Attempt);
writer.write("""
if _, ok := stack.Finalize.Get(($1P)(nil).ID()); ok {
stack.Finalize.Remove(($1P)(nil).ID())
}""", SdkGoTypes.Aws.Retry.MetricsHeader);
writer.write("stack.Deserialize.Clear()");
writer.write("stack.Build.Remove(($P)(nil).ID())", requestInvocationID);
writer.write("stack.Build.Remove($S)", "UserAgent");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ public static final class Aws {
public static final Symbol IsCredentialsProvider = AwsGoDependency.AWS_CORE.valueSymbol("IsCredentialsProvider");
public static final Symbol AnonymousCredentials = AwsGoDependency.AWS_CORE.pointableSymbol("AnonymousCredentials");


public static final class Middleware {
public static final Symbol GetRequiresLegacyEndpoints = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetRequiresLegacyEndpoints");
public static final Symbol GetSigningName = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetSigningName");
public static final Symbol GetSigningRegion = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetSigningRegion");
public static final Symbol SetSigningName = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("SetSigningName");
public static final Symbol SetSigningRegion = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("SetSigningRegion");
}


public static final class Retry {
public static final Symbol Attempt = AwsGoDependency.AWS_RETRY.pointableSymbol("Attempt");
public static final Symbol MetricsHeader = AwsGoDependency.AWS_RETRY.pointableSymbol("MetricsHeader");
}
}

public static final class Internal {
Expand Down
48 changes: 48 additions & 0 deletions credentials/endpointcreds/internal/client/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package client

import (
"context"
"github.com/aws/smithy-go/middleware"
)

type getIdentityMiddleware struct {
options Options
}

func (*getIdentityMiddleware) ID() string {
return "GetIdentity"
}

func (m *getIdentityMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}

type signRequestMiddleware struct {
}

func (*signRequestMiddleware) ID() string {
return "Signing"
}

func (m *signRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}

type resolveAuthSchemeMiddleware struct {
operation string
options Options
}

func (*resolveAuthSchemeMiddleware) ID() string {
return "ResolveAuthScheme"
}

func (m *resolveAuthSchemeMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}
1 change: 1 addition & 0 deletions credentials/endpointcreds/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ func (c *Client) GetCredentials(ctx context.Context, params *GetCredentialsInput
stack.Serialize.Add(&serializeOpGetCredential{}, smithymiddleware.After)
stack.Build.Add(&buildEndpoint{Endpoint: options.Endpoint}, smithymiddleware.After)
stack.Deserialize.Add(&deserializeOpGetCredential{}, smithymiddleware.After)
addProtocolFinalizerMiddlewares(stack, options, "GetCredentials")
retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{Retryer: options.Retryer})
middleware.AddSDKAgentKey(middleware.FeatureMetadata, ServiceID)
smithyhttp.AddErrorCloseResponseBodyMiddleware(stack)
Expand Down
20 changes: 20 additions & 0 deletions credentials/endpointcreds/internal/client/endpoints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package client

import (
"context"
"github.com/aws/smithy-go/middleware"
)

type resolveEndpointV2Middleware struct {
options Options
}

func (*resolveEndpointV2Middleware) ID() string {
return "ResolveEndpointV2"
}

func (m *resolveEndpointV2Middleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}
16 changes: 16 additions & 0 deletions credentials/endpointcreds/internal/client/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,19 @@ func stof(code int) smithy.ErrorFault {
}
return smithy.FaultClient
}

func addProtocolFinalizerMiddlewares(stack *smithymiddleware.Stack, options Options, operation string) error {
if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, smithymiddleware.Before); err != nil {
return fmt.Errorf("add ResolveAuthScheme: %w", err)
}
if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", smithymiddleware.After); err != nil {
return fmt.Errorf("add GetIdentity: %w", err)
}
if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", smithymiddleware.After); err != nil {
return fmt.Errorf("add ResolveEndpointV2: %w", err)
}
if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", smithymiddleware.After); err != nil {
return fmt.Errorf("add Signing: %w", err)
}
return nil
}
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetDynamicData.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type GetDynamicDataOutput struct {
func addGetDynamicDataMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetDynamicData",
buildGetDynamicDataPath,
buildGetDynamicDataOutput)
}
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetIAMInfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type GetIAMInfoOutput struct {
func addGetIAMInfoMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetIAMInfo",
buildGetIAMInfoPath,
buildGetIAMInfoOutput,
)
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetInstanceIdentityDocument.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type GetInstanceIdentityDocumentOutput struct {
func addGetInstanceIdentityDocumentMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetInstanceIdentityDocument",
buildGetInstanceIdentityDocumentPath,
buildGetInstanceIdentityDocumentOutput,
)
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetMetadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type GetMetadataOutput struct {
func addGetMetadataMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetMetadata",
buildGetMetadataPath,
buildGetMetadataOutput)
}
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetRegion.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type GetRegionOutput struct {
func addGetRegionMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetRegion",
buildGetInstanceIdentityDocumentPath,
buildGetRegionOutput,
)
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func addGetTokenMiddleware(stack *middleware.Stack, options Options) error {
err := addRequestMiddleware(stack,
options,
"PUT",
"GetToken",
buildGetTokenPath,
buildGetTokenOutput)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions feature/ec2/imds/api_op_GetUserData.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type GetUserDataOutput struct {
func addGetUserDataMiddleware(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
options,
"GetUserData",
buildGetUserDataPath,
buildGetUserDataOutput)
}
Expand Down
48 changes: 48 additions & 0 deletions feature/ec2/imds/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package imds

import (
"context"
"github.com/aws/smithy-go/middleware"
)

type getIdentityMiddleware struct {
options Options
}

func (*getIdentityMiddleware) ID() string {
return "GetIdentity"
}

func (m *getIdentityMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}

type signRequestMiddleware struct {
}

func (*signRequestMiddleware) ID() string {
return "Signing"
}

func (m *signRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}

type resolveAuthSchemeMiddleware struct {
operation string
options Options
}

func (*resolveAuthSchemeMiddleware) ID() string {
return "ResolveAuthScheme"
}

func (m *resolveAuthSchemeMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}
20 changes: 20 additions & 0 deletions feature/ec2/imds/endpoints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package imds

import (
"context"
"github.com/aws/smithy-go/middleware"
)

type resolveEndpointV2Middleware struct {
options Options
}

func (*resolveEndpointV2Middleware) ID() string {
return "ResolveEndpointV2"
}

func (m *resolveEndpointV2Middleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
return next.HandleFinalize(ctx, in)
}
24 changes: 23 additions & 1 deletion feature/ec2/imds/request_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (

func addAPIRequestMiddleware(stack *middleware.Stack,
options Options,
operation string,
getPath func(interface{}) (string, error),
getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
err = addRequestMiddleware(stack, options, "GET", getPath, getOutput)
err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
if err != nil {
return err
}
Expand All @@ -44,6 +45,7 @@ func addAPIRequestMiddleware(stack *middleware.Stack,
func addRequestMiddleware(stack *middleware.Stack,
options Options,
method string,
operation string,
getPath func(interface{}) (string, error),
getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
Expand Down Expand Up @@ -101,6 +103,10 @@ func addRequestMiddleware(stack *middleware.Stack,
return err
}

if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
return fmt.Errorf("add protocol finalizers: %w", err)
}

// Retry support
return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
Retryer: options.Retryer,
Expand Down Expand Up @@ -283,3 +289,19 @@ func appendURIPath(base, add string) string {
}
return reqPath
}

func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
return fmt.Errorf("add ResolveAuthScheme: %w", err)
}
if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
return fmt.Errorf("add GetIdentity: %w", err)
}
if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
return fmt.Errorf("add ResolveEndpointV2: %w", err)
}
if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
return fmt.Errorf("add Signing: %w", err)
}
return nil
}
12 changes: 11 additions & 1 deletion feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestAddRequestMiddleware(t *testing.T) {
"api request": {
AddMiddleware: func(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack, options,
"TestRequest",
func(interface{}) (string, error) {
return "/mockPath", nil
},
Expand All @@ -53,9 +54,13 @@ func TestAddRequestMiddleware(t *testing.T) {
"UserAgent",
},
ExpectFinalize: []string{
"ResolveAuthScheme",
"GetIdentity",
"ResolveEndpointV2",
"Retry",
"APITokenProvider",
"RetryMetricsHeader",
"Signing",
},
ExpectDeserialize: []string{
"APITokenProvider",
Expand All @@ -66,7 +71,7 @@ func TestAddRequestMiddleware(t *testing.T) {

"base request": {
AddMiddleware: func(stack *middleware.Stack, options Options) error {
return addRequestMiddleware(stack, options, "POST",
return addRequestMiddleware(stack, options, "POST", "TestRequest",
func(interface{}) (string, error) {
return "/mockPath", nil
},
Expand All @@ -87,8 +92,12 @@ func TestAddRequestMiddleware(t *testing.T) {
"UserAgent",
},
ExpectFinalize: []string{
"ResolveAuthScheme",
"GetIdentity",
"ResolveEndpointV2",
"Retry",
"RetryMetricsHeader",
"Signing",
},
ExpectDeserialize: []string{
"OperationDeserializer",
Expand Down Expand Up @@ -590,6 +599,7 @@ func TestRequestGetToken(t *testing.T) {
func(stack *middleware.Stack, options Options) error {
return addAPIRequestMiddleware(stack,
client.options.Copy(),
"TestRequest",
func(interface{}) (string, error) {
return "/latest/foo", nil
},
Expand Down
6 changes: 6 additions & 0 deletions service/docdb/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0f8ad11

Please sign in to comment.