Skip to content

Commit

Permalink
[Core] Add new token class and protocols (#36565)
Browse files Browse the repository at this point in the history
Add new AccessTokenInfo class and supporting protocols
AsyncSupportsTokenInfo and SupportsTokenInfo.

Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
pvaneck authored Sep 5, 2024
1 parent 5cf3b46 commit 526dca6
Show file tree
Hide file tree
Showing 16 changed files with 499 additions and 77 deletions.
10 changes: 7 additions & 3 deletions sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

### Features Added

- `AccessToken` now has an optional `refresh_on` attribute that can be used to specify when the token should be refreshed. #36183
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check the `refresh_on` attribute when determining if a token request should be made.
- Added `azure.core.AzureClouds` enum to represent the different Azure clouds.
- Added azure.core.AzureClouds enum to represent the different Azure clouds.
- Added two new credential protocol classes, `SupportsTokenInfo` and `AsyncSupportsTokenInfo`, to offer more extensibility in supporting various token acquisition scenarios. #36565
- Each new protocol class defines a `get_token_info` method that returns an `AccessTokenInfo` object.
- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. #36565
- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. #36565
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now first check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. Otherwise, the `get_token` method is used. #36565
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.

### Breaking Changes

Expand Down
85 changes: 79 additions & 6 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,64 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, TypedDict, Union, ContextManager
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
"""Represents an OAuth access token."""

token: str
"""The token string."""
expires_on: int
refresh_on: Optional[int] = None
"""The token's expiration time in Unix time."""


AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
class AccessTokenInfo:
"""Information about an OAuth access token.
This class is an alternative to `AccessToken` which provides additional information about the token.
:param str token: The token string.
:param int expires_on: The token's expiration time in Unix time.
:keyword str token_type: The type of access token. Defaults to 'Bearer'.
:keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively
refreshed. Optional.
"""

token: str
"""The token string."""
expires_on: int
"""The token's expiration time in Unix time."""
token_type: str
"""The type of access token."""
refresh_on: Optional[int]
"""Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional."""

def __init__(
self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None
) -> None:
self.token = token
self.expires_on = expires_on
self.token_type = token_type
self.refresh_on = refresh_on

def __repr__(self) -> str:
return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format(
self.token, self.expires_on, self.token_type, self.refresh_on
)


class TokenRequestOptions(TypedDict, total=False):
"""Options to use for access token requests. All parameters are optional."""

claims: str
"""Additional claims required in the token, such as those returned in a resource provider's claims
challenge following an authorization failure."""
tenant_id: str
"""The tenant ID to include in the token request."""
enable_cae: bool
"""Indicates whether to enable Continuous Access Evaluation (CAE) for the requested token."""


@runtime_checkable
Expand All @@ -30,7 +73,7 @@ def get_token(
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
enable_cae: bool = False,
**kwargs: Any
**kwargs: Any,
) -> AccessToken:
"""Request an access token for `scopes`.
Expand All @@ -48,6 +91,32 @@ def get_token(
...


@runtime_checkable
class SupportsTokenInfo(Protocol, ContextManager["SupportsTokenInfo"]):
"""Protocol for classes able to provide OAuth access tokens with additional properties."""

def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
"""Request an access token for `scopes`.
This is an alternative to `get_token` to enable certain scenarios that require additional properties
on the token.
:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions
:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing information about the token.
"""
...

def close(self) -> None:
pass


TokenProvider = Union[TokenCredential, SupportsTokenInfo]


class AzureNamedKey(NamedTuple):
"""Represents a name and key pair."""

Expand All @@ -59,8 +128,12 @@ class AzureNamedKey(NamedTuple):
"AzureKeyCredential",
"AzureSasCredential",
"AccessToken",
"AccessTokenInfo",
"SupportsTokenInfo",
"AzureNamedKeyCredential",
"TokenCredential",
"TokenRequestOptions",
"TokenProvider",
]


Expand Down
42 changes: 40 additions & 2 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
# ------------------------------------
from __future__ import annotations
from types import TracebackType
from typing import Any, Optional, AsyncContextManager, Type
from typing import Any, Optional, AsyncContextManager, Type, Union
from typing_extensions import Protocol, runtime_checkable
from .credentials import AccessToken as _AccessToken
from .credentials import (
AccessToken as _AccessToken,
AccessTokenInfo as _AccessTokenInfo,
TokenRequestOptions as _TokenRequestOptions,
)


@runtime_checkable
Expand Down Expand Up @@ -46,3 +50,37 @@ async def __aexit__(
traceback: Optional[TracebackType] = None,
) -> None:
pass


@runtime_checkable
class AsyncSupportsTokenInfo(Protocol, AsyncContextManager["AsyncSupportsTokenInfo"]):
"""Protocol for classes able to provide OAuth access tokens with additional properties."""

async def get_token_info(self, *scopes: str, options: Optional[_TokenRequestOptions] = None) -> _AccessTokenInfo:
"""Request an access token for `scopes`.
This is an alternative to `get_token` to enable certain scenarios that require additional properties
on the token.
:param str scopes: The type of access needed.
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
:paramtype options: TokenRequestOptions
:rtype: AccessTokenInfo
:return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time.
"""
...

async def close(self) -> None:
pass

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
pass


AsyncTokenProvider = Union[AsyncTokenCredential, AsyncSupportsTokenInfo]
54 changes: 35 additions & 19 deletions sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# license information.
# -------------------------------------------------------------------------
import time
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpResponse, HttpRequest
Expand All @@ -15,7 +16,7 @@
# pylint:disable=unused-import
from azure.core.credentials import (
AccessToken,
TokenCredential,
AccessTokenInfo,
AzureKeyCredential,
AzureSasCredential,
)
Expand All @@ -29,17 +30,17 @@ class _BearerTokenCredentialPolicyBase:
"""Base class for a Bearer Token Credential Policy.
:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.TokenProvider
:param str scopes: Lets you specify the type of access needed.
:keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
tokens. Defaults to False.
"""

def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs: Any) -> None:
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token: Optional["AccessToken"] = None
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
self._enable_cae: bool = kwargs.get("enable_cae", False)

@staticmethod
Expand Down Expand Up @@ -70,11 +71,29 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@property
def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

def _request_token(self, *scopes: str, **kwargs: Any) -> None:
"""Request a new token from the credential.
This will call the credential's appropriate method to get a token and store it in the policy.
:param str scopes: The type of access needed.
"""
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)

if hasattr(self._credential, "get_token_info"):
options: TokenRequestOptions = {}
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
for key in list(kwargs.keys()):
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
options[key] = kwargs.pop(key) # type: ignore[literal-required]

self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
else:
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand All @@ -98,11 +117,9 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
self._enforce_https(request)

if self._token is None or self._need_new_token:
if self._enable_cae:
self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae)
else:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token)
self._request_token(*self._scopes)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -113,10 +130,9 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
:param ~azure.core.pipeline.PipelineRequest request: the request
:param str scopes: required scopes of authentication
"""
if self._enable_cae:
kwargs.setdefault("enable_cae", self._enable_cae)
self._token = self._credential.get_token(*scopes, **kwargs)
self._update_headers(request.http_request.headers, self._token.token)
self._request_token(*scopes, **kwargs)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)

def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
"""Authorize request with a bearer token and send it to the next policy
Expand Down
Loading

0 comments on commit 526dca6

Please sign in to comment.