diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index 41807ab416e0..ba6ed793e6b0 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -20,10 +20,15 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" ) +const ( + apiVersion = "2020-11-05" +) + // Client is used to interact with the Azure Cosmos DB database service. type Client struct { endpoint string pipeline azruntime.Pipeline + gem *globalEndpointManager } // Endpoint used to create the client. @@ -36,7 +41,15 @@ func (c *Client) Endpoint() string { // cred - The credential used to authenticate with the cosmos service. // options - Optional Cosmos client options. Pass nil to accept default values. func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) { - return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), o)}, nil + preferredRegions := []string{} + if o != nil { + preferredRegions = o.PreferredRegions + } + gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0) + if err != nil { + return nil, err + } + return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), gem, o), gem: gem}, nil } // NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration. @@ -48,7 +61,16 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) ( if err != nil { return nil, err } - return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o)}, nil + preferredRegions := []string{} + if o != nil { + preferredRegions = o.PreferredRegions + } + gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0) + if err != nil { + return nil, err + } + + return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), gem, o), gem: gem}, nil } // NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration. @@ -87,7 +109,7 @@ func NewClientFromConnectionString(connectionString string, o *ClientOptions) (* return NewClientWithKey(endpoint, cred, o) } -func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline { +func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options *ClientOptions) azruntime.Pipeline { if options == nil { options = &ClientOptions{} } @@ -98,6 +120,7 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip &headerPolicies{ enableContentResponseOnWrite: options.EnableContentResponseOnWrite, }, + &globalEndpointManagerPolicy{gem: gem}, }, PerRetry: []policy.Policy{ authPolicy, @@ -106,6 +129,19 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip &options.ClientOptions) } +func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline { + if options == nil { + options = &ClientOptions{} + } + return azruntime.NewPipeline("azcosmos", serviceLibVersion, + azruntime.PipelineOptions{ + PerRetry: []policy.Policy{ + authPolicy, + }, + }, + &options.ClientOptions) +} + func createScopeFromEndpoint(endpoint string) ([]string, error) { u, err := url.Parse(endpoint) if err != nil { @@ -394,7 +430,7 @@ func (c *Client) createRequest( } req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) - req.Raw().Header.Set(headerXmsVersion, "2020-11-05") + req.Raw().Header.Set(headerXmsVersion, apiVersion) req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue) req.SetOperationValue(operationContext) diff --git a/sdk/data/azcosmos/cosmos_client_test.go b/sdk/data/azcosmos/cosmos_client_test.go index d93f4ca93d7c..b6ceea6de18e 100644 --- a/sdk/data/azcosmos/cosmos_client_test.go +++ b/sdk/data/azcosmos/cosmos_client_test.go @@ -254,8 +254,8 @@ func TestCreateRequest(t *testing.T) { t.Errorf("Expected %v, but got %v", "", req.Raw().Header.Get(headerXmsDate)) } - if req.Raw().Header.Get(headerXmsVersion) != "2020-11-05" { - t.Errorf("Expected %v, but got %v", "2020-11-05", req.Raw().Header.Get(headerXmsVersion)) + if req.Raw().Header.Get(headerXmsVersion) != apiVersion { + t.Errorf("Expected %v, but got %v", apiVersion, req.Raw().Header.Get(headerXmsVersion)) } if req.Raw().Header.Get(cosmosHeaderSDKSupportedCapabilities) != supportedCapabilitiesHeaderValue { diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go index 82770d999ec9..5383864a4407 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go @@ -17,7 +17,8 @@ import ( const defaultUnavailableLocationRefreshInterval = 5 * time.Minute type globalEndpointManager struct { - client *Client + clientEndpoint string + pipeline azruntime.Pipeline preferredLocations []string locationCache *locationCache refreshTimeInterval time.Duration @@ -25,8 +26,8 @@ type globalEndpointManager struct { lastUpdateTime time.Time } -func newGlobalEndpointManager(client *Client, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) { - endpoint, err := url.Parse(client.endpoint) +func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) { + endpoint, err := url.Parse(clientEndpoint) if err != nil { return &globalEndpointManager{}, err } @@ -36,7 +37,8 @@ func newGlobalEndpointManager(client *Client, preferredLocations []string, refre } gem := &globalEndpointManager{ - client: client, + clientEndpoint: clientEndpoint, + pipeline: pipeline, preferredLocations: preferredLocations, locationCache: newLocationCache(preferredLocations, *endpoint), refreshTimeInterval: refreshTimeInterval, @@ -110,24 +112,34 @@ func (gem *globalEndpointManager) GetAccountProperties(ctx context.Context) (acc resourceAddress: "", } - path, err := generatePathForNameBased(resourceTypeDatabaseAccount, "", false) + ctxt, cancel := context.WithTimeout(ctx, 60*time.Second) + defer cancel() + req, err := azruntime.NewRequest(ctxt, http.MethodGet, gem.clientEndpoint) if err != nil { - return accountProperties{}, fmt.Errorf("failed to generate path for name-based request: %v", err) + return accountProperties{}, err } - ctx, cancel := context.WithTimeout(ctx, 60*time.Second) - azResponse, err := gem.client.sendGetRequest(path, ctx, operationContext, nil, nil) - cancel() + req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Raw().Header.Set(headerXmsVersion, apiVersion) + req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue) + + req.SetOperationValue(operationContext) + + azResponse, err := gem.pipeline.Do(req) if err != nil { - return accountProperties{}, fmt.Errorf("failed to retrieve account properties: %v", err) + return accountProperties{}, err } - properties, err := newAccountProperties(azResponse) - if err != nil { - return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err) + successResponse := (azResponse.StatusCode >= 200 && azResponse.StatusCode < 300) + if successResponse { + properties, err := newAccountProperties(azResponse) + if err != nil { + return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err) + } + return properties, nil } - return properties, nil + return accountProperties{}, newCosmosError(azResponse) } func newAccountProperties(azResponse *http.Response) (accountProperties, error) { diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go new file mode 100644 index 000000000000..9265c3522663 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +type globalEndpointManagerPolicy struct { + gem *globalEndpointManager +} + +func (p *globalEndpointManagerPolicy) Do(req *policy.Request) (*http.Response, error) { + shouldRefresh := p.gem.ShouldRefresh() + if shouldRefresh { + go func() { + _ = p.gem.Update(context.Background()) + }() + } + return req.Next() +} diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go index 4c0089e55c74..0273aa6c83d4 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go @@ -34,11 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - client := &Client{endpoint: srv.URL(), pipeline: pl} - - preferredRegions := []string{"West US", "Central US"} - - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) assert.NoError(t, err) writeEndpoints, err := gem.GetWriteEndpoints() @@ -50,6 +46,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) { expectedWriteEndpoints := []url.URL{ *serverEndpoint, } + assert.Equal(t, expectedWriteEndpoints, writeEndpoints) } @@ -60,11 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - client := &Client{endpoint: srv.URL(), pipeline: pl} - - preferredRegions := []string{"West US", "Central US"} - - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) assert.NoError(t, err) readEndpoints, err := gem.GetReadEndpoints() @@ -88,12 +81,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) { client := &Client{endpoint: srv.URL(), pipeline: pl} - preferredRegions := []string{"West US", "Central US"} - - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + endpoint, err := url.Parse(client.endpoint) assert.NoError(t, err) - endpoint, err := url.Parse(client.endpoint) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) assert.NoError(t, err) err = gem.MarkEndpointUnavailableForRead(*endpoint) @@ -112,12 +103,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) { client := &Client{endpoint: srv.URL(), pipeline: pl} - preferredRegions := []string{"West US", "Central US"} - - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + endpoint, err := url.Parse(client.endpoint) assert.NoError(t, err) - endpoint, err := url.Parse(client.endpoint) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) assert.NoError(t, err) err = gem.MarkEndpointUnavailableForWrite(*endpoint) @@ -130,7 +119,6 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) { func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) { srv, close := mock.NewTLSServer() defer close() - srv.SetResponse(mock.WithStatusCode(http.StatusOK)) westRegion := accountRegion{ Name: "West US", @@ -144,19 +132,17 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) { } jsonString, err := json.Marshal(properties) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + + srv.SetResponse(mock.WithStatusCode(200)) srv.SetResponse(mock.WithBody(jsonString)) pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - client := &Client{endpoint: srv.URL(), pipeline: pl} - - gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Minute) + serverEndpoint, err := url.Parse(srv.URL()) assert.NoError(t, err) - serverEndpoint, err := url.Parse(srv.URL()) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute) assert.NoError(t, err) err = gem.Update(context.Background()) @@ -175,11 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - client := &Client{endpoint: srv.URL(), pipeline: pl} - - preferredRegions := []string{"West US", "Central US"} - - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) assert.NoError(t, err) accountProps, err := gem.GetAccountProperties(context.Background()) @@ -212,13 +194,13 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) { mockLc.useMultipleWriteLocations = true mockGem := globalEndpointManager{ - client: client, + clientEndpoint: client.endpoint, preferredLocations: preferredRegions, locationCache: mockLc, refreshTimeInterval: 5 * time.Minute, } - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute) assert.NoError(t, err) // Multiple locations should be false for default GEM @@ -254,9 +236,8 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) { srv.SetResponse(mock.WithBody(jsonString)) pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv}) - client := &Client{endpoint: srv.URL(), pipeline: pl} - gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Second) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second) assert.NoError(t, err) // Call update concurrently and see how many times the policy gets called diff --git a/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go b/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go index 3409758c4f42..be2a47ca9a04 100644 --- a/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go +++ b/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go @@ -19,7 +19,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) { preferredRegions := []string{} emulatorRegion := accountRegion{Name: emulatorRegionName, Endpoint: "https://127.0.0.1:8081/"} - gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute) assert.NoError(t, err) accountProps, err := gem.GetAccountProperties(context.Background()) @@ -61,7 +61,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) { assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation) assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation) - //update and assert available locations are now populated in location cache + // Run Update() and assert available locations are now populated in location cache err = gem.Update(context.Background()) assert.NoError(t, err) locationInfo = gem.locationCache.locationInfo @@ -73,3 +73,34 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) { assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1) assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1) } + +func TestGlobalEndpointManagerPolicyEmulator(t *testing.T) { + emulatorTests := newEmulatorTests(t) + client := emulatorTests.getClient(t) + emulatorRegionName := "South Central US" + + // Assert location cache is not populated until update() is called within the policy + locationInfo := client.gem.locationCache.locationInfo + availableLocation := []string{} + availableEndpointsByLocation := map[string]url.URL{} + + assert.Equal(t, locationInfo.availReadLocations, availableLocation) + assert.Equal(t, locationInfo.availWriteLocations, availableLocation) + assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation) + assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation) + + // Assert that information gets populated by the gem policy after running an http request (read item) + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err := container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + assert.Error(t, err) + + locationInfo = client.gem.locationCache.locationInfo + + assert.Equal(t, len(locationInfo.availReadLocations), len(availableLocation)+1) + assert.Equal(t, len(locationInfo.availWriteLocations), len(availableLocation)+1) + assert.Equal(t, locationInfo.availWriteLocations[0], emulatorRegionName) + assert.Equal(t, locationInfo.availReadLocations[0], emulatorRegionName) + assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1) + assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1) +} diff --git a/sdk/data/azcosmos/shared_key_credential_test.go b/sdk/data/azcosmos/shared_key_credential_test.go index 4ac2d3f83570..7e632f563561 100644 --- a/sdk/data/azcosmos/shared_key_credential_test.go +++ b/sdk/data/azcosmos/shared_key_credential_test.go @@ -69,7 +69,7 @@ func Test_buildCanonicalizedAuthHeaderFromRequest(t *testing.T) { } req.Raw().Header.Set(headerXmsDate, xmsDate) - req.Raw().Header.Set(headerXmsVersion, "2020-11-05") + req.Raw().Header.Set(headerXmsVersion, apiVersion) req.SetOperationValue(operationContext) authHeader, _ := cred.buildCanonicalizedAuthHeaderFromRequest(req) @@ -102,7 +102,7 @@ func Test_buildCanonicalizedAuthHeaderFromRequestWithRid(t *testing.T) { } req.Raw().Header.Set(headerXmsDate, xmsDate) - req.Raw().Header.Set(headerXmsVersion, "2020-11-05") + req.Raw().Header.Set(headerXmsVersion, apiVersion) req.SetOperationValue(operationContext) authHeader, _ := cred.buildCanonicalizedAuthHeaderFromRequest(req) @@ -135,7 +135,7 @@ func Test_buildCanonicalizedAuthHeaderFromRequestWithEscapedCharacters(t *testin } req.Raw().Header.Set(headerXmsDate, xmsDate) - req.Raw().Header.Set(headerXmsVersion, "2020-11-05") + req.Raw().Header.Set(headerXmsVersion, apiVersion) req.SetOperationValue(operationContext) authHeader, _ := cred.buildCanonicalizedAuthHeaderFromRequest(req)