Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved Rate Limit handling #1356

Merged
merged 7 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 60 additions & 136 deletions okta/internal/apimutex/apimutex.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,12 @@ package apimutex

import (
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
)

const (
APPS_KEY = "apps"
APPID_KEY = "app-id"
CAS_KEY = "cas-id"
CLIENTS_KEY = "clients"
DEVICES_KEY = "devices"
EVENTS_KEY = "events"
GROUPS_KEY = "groups"
GROUPID_KEY = "group-id"
LOGS_KEY = "logs"
USERS_KEY = "users"
USERID_KEY = "user-id"
USERME_KEY = "user-me"
USERIDGET_KEY = "user-id-get"
OTHER_KEY = "other"
)

// APIMutex synchronizes keeping account of current known rate limit values
// from Okta management endpoints. See:
// https://developer.okta.com/docs/reference/rl-global-mgmt/
Expand All @@ -35,8 +17,9 @@ const (
// react appropriately.
type APIMutex struct {
lock sync.Mutex
status map[string]*APIStatus
capacity int
status map[string]*APIStatus
buckets map[string]string
}

// APIStatus is used to hold rate limit information from Okta's API, see:
Expand All @@ -45,35 +28,22 @@ type APIStatus struct {
limit int
remaining int
reset int64 // UTC epoch time in seconds
class string
}

// NewAPIMutex returns a new api mutex object that represents untilized
// capacity under the specified capacity percentage.
func NewAPIMutex(capacity int) (*APIMutex, error) {
if capacity < 1 || capacity > 100 {
return nil, fmt.Errorf("expecting capacity as whole number > 0 and <= 100, was %d", capacity)
}
status := map[string]*APIStatus{
APPS_KEY: {class: APPS_KEY},
APPID_KEY: {class: APPID_KEY},
CAS_KEY: {class: CAS_KEY},
CLIENTS_KEY: {class: CLIENTS_KEY},
DEVICES_KEY: {class: DEVICES_KEY},
LOGS_KEY: {class: LOGS_KEY},
EVENTS_KEY: {class: EVENTS_KEY},
GROUPS_KEY: {class: GROUPS_KEY},
GROUPID_KEY: {class: GROUPID_KEY},
OTHER_KEY: {class: OTHER_KEY},
USERS_KEY: {class: USERS_KEY},
USERID_KEY: {class: USERID_KEY},
USERME_KEY: {class: USERME_KEY},
USERIDGET_KEY: {class: USERIDGET_KEY},
}
return &APIMutex{
rootStatus := &APIStatus{}
mutex := &APIMutex{
capacity: capacity,
status: status,
}, nil
status: map[string]*APIStatus{
"/": rootStatus,
},
buckets: map[string]string{},
}
mutex.initRateLimitLookup()

return mutex, nil
}

// HasCapacity approximates if there is capacity below the api mutex's maximum
Expand All @@ -98,9 +68,7 @@ func (m *APIMutex) Update(method, endPoint string, limit, remaining int, reset i
m.lock.Lock()
defer m.lock.Unlock()

key := m.normalizeKey(method, endPoint)
status := m.status[key]

status := m.get(method, endPoint)
if reset > status.reset {
// reset value greater than current reset implies we are in a new Okta API
// one minute window. set/reset values.
Expand All @@ -120,97 +88,30 @@ func (m *APIMutex) Update(method, endPoint string, limit, remaining int, reset i
}
}

// Status return the APIStatus for the given class of endpoint.
// Status Returns the APIStatus for the given method + endpoint combination.
func (m *APIMutex) Status(method, endPoint string) *APIStatus {
return m.get(method, endPoint)
}

var (
reAppId = regexp.MustCompile(`/api/v1/apps/[^/]+[/\w]*$`)
reGroupId = regexp.MustCompile("/api/v1/groups/[^/]+$")
reUserId = regexp.MustCompile("/api/v1/users/[^/]+$")
)
// Class Returns the api endpoint class.
func (m *APIMutex) Class(method, endPoint string) string {
path := reOktaID.ReplaceAllString(endPoint, "ID")
return m.normalizedKey(method, path)
}

func (m *APIMutex) normalizeKey(method, endPoint string) string {
// Okta internal: see rate-limit-mappings-CLASSIC-DEFAULT.txt file in core
// repo. It corresponds to:
// https://developer.okta.com/docs/reference/rl-best-practices/
//
// TODO: API rate limits can be overwritten by the org admin, we should come
// up with a way to accommodate for that. Perhaps caching an APIStatus
// struct on the SHA of http method + URI path.

getPutDelete := (http.MethodGet == method) ||
(http.MethodPut == method) ||
(http.MethodDelete == method)
postPutDelete := (http.MethodPost == method) ||
(http.MethodPut == method) ||
(http.MethodDelete == method)
getPostPutDelete := (http.MethodGet == method) ||
(http.MethodPost == method) ||
(http.MethodPut == method) ||
(http.MethodDelete == method)
var result string

switch {
// 1. [GET|POST|PUT|DELETE] /api/v1/apps/${id}
case reAppId.MatchString(endPoint) && getPostPutDelete:
result = APPID_KEY

// 2. starts with /api/v1/apps
case strings.HasPrefix(endPoint, "/api/v1/apps"):
result = APPS_KEY

// 3. [GET|PUT|DELETE] /api/v1/groups/${id}
case reGroupId.MatchString(endPoint) && getPutDelete:
result = GROUPID_KEY

// 4. starts with /api/v1/groups
case strings.HasPrefix(endPoint, "/api/v1/groups"):
result = GROUPS_KEY

// 5. GET /api/v1/users/me
// NOTE: this is not documented in the devex docs
case endPoint == "/api/v1/users/me" && method == http.MethodGet:
result = USERME_KEY

// 6. [POST|PUT|DELETE] /api/v1/users/${id}
case reUserId.MatchString(endPoint) && postPutDelete:
result = USERID_KEY

// 7. [GET] /api/v1/users/${idOrLogin}
case reUserId.MatchString(endPoint) && method == http.MethodGet:
result = USERIDGET_KEY

// 8. starts with /api/v1/users
case strings.HasPrefix(endPoint, "/api/v1/users"):
result = USERS_KEY

// 9. GET /api/v1/logs
case endPoint == "/api/v1/logs" && method == http.MethodGet:
result = LOGS_KEY

// 10. GET /api/v1/events
case endPoint == "/api/v1/events" && method == http.MethodGet:
result = EVENTS_KEY

// 11. GET /oauth2/v1/clients
case endPoint == "/oauth2/v1/clients" && method == http.MethodGet:
result = CLIENTS_KEY

// 12. GET /api/v1/certificateAuthorities
case endPoint == "/api/v1/certificateAuthorities" && method == http.MethodGet:
result = CAS_KEY

// 13. GET /api/v1/devices
case endPoint == "/api/v1/devices" && method == http.MethodGet:
result = DEVICES_KEY

// 14. GET /api/v1
default:
result = "other"
// Bucket Returns the rate limit bucket the api endpoint falls into.
func (m *APIMutex) Bucket(method, endPoint string) string {
path := reOktaID.ReplaceAllString(endPoint, "ID")
key := m.normalizedKey(method, path)
bucket, ok := m.buckets[key]
if !ok {
return "/"
}
return result
return bucket
}

func (m *APIMutex) normalizedKey(method, endPoint string) string {
return fmt.Sprintf("%s %s", method, endPoint)
}

// Reset returns the current reset value of the api status object.
Expand All @@ -228,12 +129,35 @@ func (s *APIStatus) Remaining() int {
return s.remaining
}

// Class returns the api endpoint class for this status.
func (s *APIStatus) Class() string {
return s.class
}
var (
reOktaID = regexp.MustCompile(`[\w]{20}`)
)

func (m *APIMutex) get(method, endPoint string) *APIStatus {
key := m.normalizeKey(method, endPoint)
return m.status[key]
// the important point here is the replace all is performing this
// transformation for the bucket lookup /api/v1/users/abcdefghij0123456789
// to /api/v1/users/ID
path := reOktaID.ReplaceAllString(endPoint, "ID")
key := m.normalizedKey(method, path)
bucket, ok := m.buckets[key]
if !ok {
return m.status["/"]
}
return m.status[bucket]
}

func (m *APIMutex) initRateLimitLookup() {
for _, line := range rateLimitLines {
vals := strings.Split(line, " ")
path := vals[0]
method := vals[1]
bucket := vals[2]

key := m.normalizedKey(method, path)
m.buckets[key] = bucket

if _, ok := m.status[bucket]; !ok {
m.status[bucket] = &APIStatus{}
}
}
}
129 changes: 0 additions & 129 deletions okta/internal/apimutex/apimutex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,135 +81,6 @@ func TestUpdate(t *testing.T) {
}
}

func TestGet(t *testing.T) {
amu, err := NewAPIMutex(100)
if err != nil {
t.Fatalf("api mutex constructor had error %+v", err)
}
if len(amu.status) != 14 {
t.Fatalf("amu status map should sized 14 but was sized %d", len(amu.status))
}
keys := []string{
"users",
"user-id",
"user-me",
"user-id-get",
"apps",
"app-id",
"groups",
"group-id",
"cas-id",
"clients",
"devices",
"events",
"logs",
"other",
}
for _, key := range keys {
if _, found := amu.status[key]; !found {
t.Fatalf("amu should have status for key %q", key)
}
}
}

func TestNormalizeKey(t *testing.T) {
// Attempts to cover the rules listed at:
// https://developer.okta.com/docs/reference/rl-global-mgmt/
tests := []struct {
method string
endPoint string
expected string
}{
// 1. [GET|POST|PUT|DELETE] /api/v1/apps/${id}
{method: http.MethodGet, endPoint: "/api/v1/apps/TESTID", expected: "app-id"},
{method: http.MethodPut, endPoint: "/api/v1/apps/TESTID", expected: "app-id"},
{method: http.MethodDelete, endPoint: "/api/v1/apps/TESTID", expected: "app-id"},
{method: http.MethodGet, endPoint: "/api/v1/apps/TESTID/users", expected: "app-id"},
{method: http.MethodPost, endPoint: "/api/v1/apps/TESTID/users/USERID", expected: "app-id"},
{method: http.MethodDelete, endPoint: "/api/v1/apps/TESTID/users/USERID", expected: "app-id"},
{method: http.MethodGet, endPoint: "/api/v1/apps/TESTID/groups", expected: "app-id"},
{method: http.MethodPut, endPoint: "/api/v1/apps/TESTID/groups/GROUPID", expected: "app-id"},
{method: http.MethodDelete, endPoint: "/api/v1/apps/TESTID/groups/GROUPID", expected: "app-id"},
{method: http.MethodPost, endPoint: "/api/v1/apps/TESTID/logo", expected: "app-id"},

// 2. starts with /api/v1/apps
{method: http.MethodGet, endPoint: "/api/v1/apps", expected: "apps"},

// 3. [GET|PUT|DELETE] /api/v1/groups/${id}
{method: http.MethodGet, endPoint: "/api/v1/groups/TESTID", expected: "group-id"},
{method: http.MethodPut, endPoint: "/api/v1/groups/TESTID", expected: "group-id"},
{method: http.MethodDelete, endPoint: "/api/v1/groups/TESTID", expected: "group-id"},

// 4. starts with /api/v1/groups
{method: http.MethodGet, endPoint: "/api/v1/groups", expected: "groups"},
{method: http.MethodGet, endPoint: "/api/v1/groups/TESTID/apps", expected: "groups"},

// 5. GET /api/v1/users/me
{method: http.MethodGet, endPoint: "/api/v1/users/me", expected: "user-me"},

// 6. [POST|PUT|DELETE] /api/v1/users/${id}
{method: http.MethodPost, endPoint: "/api/v1/users/TESTID", expected: "user-id"},
{method: http.MethodPut, endPoint: "/api/v1/users/TESTID", expected: "user-id"},
{method: http.MethodDelete, endPoint: "/api/v1/users/TESTID", expected: "user-id"},

// 7. [GET] /api/v1/users/${idOrLogin}
{method: http.MethodGet, endPoint: "/api/v1/users/TESTID", expected: "user-id-get"},

// 8. starts with /api/v1/users
{method: http.MethodGet, endPoint: "/api/v1/users", expected: "users"},
{method: http.MethodGet, endPoint: "/api/v1/users/TESTID/devices", expected: "users"},

// 9. GET /api/v1/logs
{method: http.MethodGet, endPoint: "/api/v1/logs", expected: "logs"},

// 10. GET /api/v1/events
{method: http.MethodGet, endPoint: "/api/v1/events", expected: "events"},

// 11. GET /oauth2/v1/clients
{method: http.MethodGet, endPoint: "/oauth2/v1/clients", expected: "clients"},

// 12. GET /api/v1/certificateAuthorities
{method: http.MethodGet, endPoint: "/api/v1/certificateAuthorities", expected: "cas-id"},

// 13. GET /api/v1/devices
{method: http.MethodGet, endPoint: "/api/v1/devices", expected: "devices"},

// 14. GET /api/v1
{method: http.MethodGet, endPoint: "/api/v1/authorizationServers", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/authorizationServers/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/behaviors", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/behaviors/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/domains", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/domains/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/idps", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/idps/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/internal", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/internal/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/mappings", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/mappings/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/meta", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/meta/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/org", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/org/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/policies", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/policies/TESTID", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/templates", expected: "other"},
{method: http.MethodGet, endPoint: "/api/v1/templates/TESTID", expected: "other"},
}

amu, err := NewAPIMutex(100)
if err != nil {
t.Fatalf("api mutex constructor had error %+v", err)
}
for _, tc := range tests {
// test that private normalizedKey function is operating correctly
key := amu.normalizeKey(tc.method, tc.endPoint)
if key != tc.expected {
t.Fatalf("got %q, expected %q for method: %q, endPoint: %q", key, tc.expected, tc.method, tc.endPoint)
}
}
}

func minRemaining(remaining []int) int {
var result int
first := true
Expand Down
Loading