Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] Adds global endpoint manager policy and links GEM to client #22223

Merged
merged 13 commits into from
Jan 17, 2024
32 changes: 29 additions & 3 deletions sdk/data/azcosmos/cosmos_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
type Client struct {
endpoint string
pipeline azruntime.Pipeline
gem *globalEndpointManager
}

// Endpoint used to create the client.
Expand All @@ -36,7 +37,12 @@ func (c *Client) Endpoint() string {
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), o)}, nil
//need to pass in preferredRegions from options here once those changes are merged
ealsur marked this conversation as resolved.
Show resolved Hide resolved
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), []string{}, 0)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), gem, o), gem: gem}, nil
}

// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
Expand All @@ -48,7 +54,13 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o)}, nil
//need to pass in preferredRegions from options here once those changes are merged
ealsur marked this conversation as resolved.
Show resolved Hide resolved
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), []string{}, 0)
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}

return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), gem, o), gem: gem}, nil
}

// NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration.
Expand Down Expand Up @@ -87,7 +99,7 @@ func NewClientFromConnectionString(connectionString string, o *ClientOptions) (*
return NewClientWithKey(endpoint, cred, o)
}

func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
Expand All @@ -98,7 +110,21 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip
&headerPolicies{
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
},
&globalEndpointManagerPolicy{gem: gem},
},
PerRetry: []policy.Policy{
authPolicy,
},
},
&options.ClientOptions)
}

func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
return azruntime.NewPipeline("azcosmos", serviceLibVersion,
azruntime.PipelineOptions{
PerRetry: []policy.Policy{
authPolicy,
},
Expand Down
44 changes: 33 additions & 11 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ import (
const defaultUnavailableLocationRefreshInterval = 5 * time.Minute

type globalEndpointManager struct {
client *Client
clientEndpoint string
pipeline azruntime.Pipeline
preferredLocations []string
locationCache *locationCache
refreshTimeInterval time.Duration
gemMutex sync.Mutex
lastUpdateTime time.Time
}

func newGlobalEndpointManager(client *Client, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(client.endpoint)
func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(clientEndpoint)
if err != nil {
return &globalEndpointManager{}, err
}
Expand All @@ -36,7 +37,8 @@ func newGlobalEndpointManager(client *Client, preferredLocations []string, refre
}

gem := &globalEndpointManager{
client: client,
clientEndpoint: clientEndpoint,
pipeline: pipeline,
preferredLocations: preferredLocations,
locationCache: newLocationCache(preferredLocations, *endpoint),
refreshTimeInterval: refreshTimeInterval,
Expand Down Expand Up @@ -115,19 +117,39 @@ func (gem *globalEndpointManager) GetAccountProperties(ctx context.Context) (acc
return accountProperties{}, fmt.Errorf("failed to generate path for name-based request: %v", err)
}

ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
azResponse, err := gem.client.sendGetRequest(path, ctx, operationContext, nil, nil)
cancel()
finalURL := gem.clientEndpoint
if path != "" {
finalURL = azruntime.JoinPaths(gem.clientEndpoint, path)
}
simorenoh marked this conversation as resolved.
Show resolved Hide resolved

ctxt, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
req, err := azruntime.NewRequest(ctxt, http.MethodGet, finalURL)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to retrieve account properties: %v", err)
return accountProperties{}, err
}

properties, err := newAccountProperties(azResponse)
req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, "2020-11-05")
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)

req.SetOperationValue(operationContext)

azResponse, err := gem.pipeline.Do(req)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
return accountProperties{}, err
}

return properties, nil
successResponse := (azResponse.StatusCode >= 200 && azResponse.StatusCode < 300) || azResponse.StatusCode == 304
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
if successResponse {
properties, err := newAccountProperties(azResponse)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
}
return properties, nil
}

return accountProperties{}, newCosmosError(azResponse)
}

func newAccountProperties(azResponse *http.Response) (accountProperties, error) {
Expand Down
25 changes: 25 additions & 0 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"context"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

type globalEndpointManagerPolicy struct {
gem *globalEndpointManager
}

func (p *globalEndpointManagerPolicy) Do(req *policy.Request) (*http.Response, error) {
simorenoh marked this conversation as resolved.
Show resolved Hide resolved
shouldRefresh := p.gem.ShouldRefresh()
if shouldRefresh {
go func() {
_ = p.gem.Update(context.Background())
}()
}
return req.Next()
}
51 changes: 16 additions & 35 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

writeEndpoints, err := gem.GetWriteEndpoints()
Expand All @@ -50,6 +46,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {
expectedWriteEndpoints := []url.URL{
*serverEndpoint,
}

assert.Equal(t, expectedWriteEndpoints, writeEndpoints)
}

Expand All @@ -60,11 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

readEndpoints, err := gem.GetReadEndpoints()
Expand All @@ -88,12 +81,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForRead(*endpoint)
Expand All @@ -112,12 +103,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForWrite(*endpoint)
Expand All @@ -130,7 +119,6 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {
func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))

westRegion := accountRegion{
Name: "West US",
Expand All @@ -144,19 +132,17 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
}

jsonString, err := json.Marshal(properties)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)

srv.SetResponse(mock.WithStatusCode(200))
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Minute)
serverEndpoint, err := url.Parse(srv.URL())
assert.NoError(t, err)

serverEndpoint, err := url.Parse(srv.URL())
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

err = gem.Update(context.Background())
Expand All @@ -175,11 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -212,13 +194,13 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) {
mockLc.useMultipleWriteLocations = true

mockGem := globalEndpointManager{
client: client,
clientEndpoint: client.endpoint,
preferredLocations: preferredRegions,
locationCache: mockLc,
refreshTimeInterval: 5 * time.Minute,
}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

// Multiple locations should be false for default GEM
Expand Down Expand Up @@ -254,9 +236,8 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) {
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv})
client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Second)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second)
assert.NoError(t, err)

// Call update concurrently and see how many times the policy gets called
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
preferredRegions := []string{}
emulatorRegion := accountRegion{Name: emulatorRegionName, Endpoint: "https://127.0.0.1:8081/"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down