Skip to content

Commit

Permalink
add SpecifcAzureCredential to DefaultAzureCredential
Browse files Browse the repository at this point in the history
  • Loading branch information
cataggar committed Dec 29, 2023
1 parent 9e9318b commit 96a78ed
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ pub struct AppServiceManagedIdentityCredential {
}

impl AppServiceManagedIdentityCredential {
pub fn create(options: TokenCredentialOptions) -> azure_core::Result<Self> {
pub fn create(options: impl Into<TokenCredentialOptions>) -> azure_core::Result<Self> {
let options = options.into();
let env = options.env();
let endpoint = Url::parse(&env.var(ENDPOINT_ENV)?)?;
Ok(Self {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub struct ClientCertificateCredential {
impl ClientCertificateCredential {
/// Create a new `ClientCertificateCredential`
pub fn new<C, P>(
options: TokenCredentialOptions,
options: impl Into<TokenCredentialOptions>,
tenant_id: String,
client_id: String,
client_certificate: C,
Expand Down
92 changes: 89 additions & 3 deletions sdk/identity/src/token_credentials/default_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::time::Duration;

use crate::{
timeout::TimeoutExt, token_credentials::cache::TokenCache, AppServiceManagedIdentityCredential,
AzureCliCredential, EnvironmentCredential, TokenCredentialOptions,
AzureCliCredential, EnvironmentCredential, SpecificAzureCredential, TokenCredentialOptions,
VirtualMachineManagedIdentityCredential,
};
use azure_core::{
Expand All @@ -13,6 +13,7 @@ use azure_core::{
/// Provides a mechanism of selectively disabling credentials used for a `DefaultAzureCredential` instance
pub struct DefaultAzureCredentialBuilder {
options: TokenCredentialOptions,
include_specific_credential: bool,
include_environment_credential: bool,
include_app_service_managed_identity_credential: bool,
include_virtual_machine_managed_identity_credential: bool,
Expand All @@ -23,6 +24,7 @@ impl Default for DefaultAzureCredentialBuilder {
fn default() -> Self {
Self {
options: TokenCredentialOptions::default(),
include_specific_credential: true,
include_environment_credential: true,
include_app_service_managed_identity_credential: true,
// Unable to quickly detect if running in Azure VM, so it is disabled by default.
Expand All @@ -38,8 +40,14 @@ impl DefaultAzureCredentialBuilder {
Self::default()
}

pub fn with_options(&mut self, options: TokenCredentialOptions) -> &mut Self {
self.options = options;
pub fn with_options(&mut self, options: impl Into<TokenCredentialOptions>) -> &mut Self {
self.options = options.into();
self
}

/// Exclude specific credential
pub fn exclude_specific_credential(&mut self) -> &mut Self {
self.include_specific_credential = false;
self
}

Expand Down Expand Up @@ -83,6 +91,9 @@ impl DefaultAzureCredentialBuilder {
/// Get a list of the credential types to include.
fn included(&self) -> Vec<DefaultAzureCredentialType> {
let mut sources = Vec::new();
if self.include_specific_credential {
sources.push(DefaultAzureCredentialType::Specific);
}
if self.include_environment_credential {
sources.push(DefaultAzureCredentialType::Environment);
}
Expand All @@ -102,9 +113,18 @@ impl DefaultAzureCredentialBuilder {
&self,
included: &Vec<DefaultAzureCredentialType>,
) -> Vec<DefaultAzureCredentialEnum> {
// If specific credential is included, try to create it.
// Use only the specific credential if it is created successfully.
if self.include_specific_credential {
if let Ok(credential) = SpecificAzureCredential::create(self.options.clone()) {
return vec![DefaultAzureCredentialEnum::Specific(credential)];
}
}

let mut sources = Vec::<DefaultAzureCredentialEnum>::with_capacity(included.len());
for source in included {
match source {
DefaultAzureCredentialType::Specific => {}
DefaultAzureCredentialType::Environment => {
if let Ok(credential) = EnvironmentCredential::create(self.options.clone()) {
sources.push(DefaultAzureCredentialEnum::Environment(credential));
Expand Down Expand Up @@ -143,6 +163,7 @@ impl DefaultAzureCredentialBuilder {
/// Types that may be enabled for use by `DefaultAzureCredential`.
#[derive(Debug, PartialEq)]
enum DefaultAzureCredentialType {
Specific,
Environment,
AppService,
VirtualMachine,
Expand All @@ -152,6 +173,8 @@ enum DefaultAzureCredentialType {
/// Types of `TokenCredential` supported by `DefaultAzureCredential`
#[derive(Debug)]
pub enum DefaultAzureCredentialEnum {
/// A `TokenCredential` instance specified with an `AZURE_CREDENTIAL_TYPE` environment variable.
Specific(SpecificAzureCredential),
/// `TokenCredential` from environment variable.
Environment(EnvironmentCredential),
/// `TokenCredential` from managed identity that has been assigned to an App Service.
Expand All @@ -167,6 +190,10 @@ pub enum DefaultAzureCredentialEnum {
impl TokenCredential for DefaultAzureCredentialEnum {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
match self {
DefaultAzureCredentialEnum::Specific(credential) => credential
.get_token(scopes)
.await
.context(ErrorKind::Credential, "error getting specific credential"),
DefaultAzureCredentialEnum::Environment(credential) => {
credential.get_token(scopes).await.context(
ErrorKind::Credential,
Expand Down Expand Up @@ -206,6 +233,7 @@ impl TokenCredential for DefaultAzureCredentialEnum {
/// Clear the credential's cache.
async fn clear_cache(&self) -> azure_core::Result<()> {
match self {
DefaultAzureCredentialEnum::Specific(credential) => credential.clear_cache().await,
DefaultAzureCredentialEnum::Environment(credential) => credential.clear_cache().await,
DefaultAzureCredentialEnum::AppService(credential) => credential.clear_cache().await,
DefaultAzureCredentialEnum::VirtualMachine(credential) => {
Expand Down Expand Up @@ -297,31 +325,39 @@ fn format_aggregate_error(errors: &[Error]) -> String {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
env::{EnvEnum, MemEnv},
SpecificAzureCredentialEnum,
};

#[test]
fn test_builder_included_credential_flags() {
let builder = DefaultAzureCredentialBuilder::new();
assert!(builder.include_specific_credential);
assert!(builder.include_azure_cli_credential);
assert!(builder.include_environment_credential);
assert!(builder.include_app_service_managed_identity_credential);
assert!(!builder.include_virtual_machine_managed_identity_credential);

let mut builder = DefaultAzureCredentialBuilder::new();
builder.exclude_azure_cli_credential();
assert!(builder.include_specific_credential);
assert!(!builder.include_azure_cli_credential);
assert!(builder.include_environment_credential);
assert!(builder.include_app_service_managed_identity_credential);
assert!(!builder.include_virtual_machine_managed_identity_credential);

let mut builder = DefaultAzureCredentialBuilder::new();
builder.exclude_environment_credential();
assert!(builder.include_specific_credential);
assert!(builder.include_azure_cli_credential);
assert!(!builder.include_environment_credential);
assert!(builder.include_app_service_managed_identity_credential);
assert!(!builder.include_virtual_machine_managed_identity_credential);

let mut builder = DefaultAzureCredentialBuilder::new();
builder.exclude_managed_identity_credential();
assert!(builder.include_specific_credential);
assert!(builder.include_azure_cli_credential);
assert!(builder.include_environment_credential);
assert!(!builder.include_app_service_managed_identity_credential);
Expand All @@ -335,6 +371,7 @@ mod tests {
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Specific,
DefaultAzureCredentialType::Environment,
DefaultAzureCredentialType::AppService,
DefaultAzureCredentialType::AzureCli,
Expand All @@ -350,6 +387,7 @@ mod tests {
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Specific,
DefaultAzureCredentialType::Environment,
DefaultAzureCredentialType::AppService,
DefaultAzureCredentialType::VirtualMachine,
Expand All @@ -366,6 +404,7 @@ mod tests {
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Specific,
DefaultAzureCredentialType::AppService,
DefaultAzureCredentialType::AzureCli,
]
Expand All @@ -380,6 +419,7 @@ mod tests {
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Specific,
DefaultAzureCredentialType::Environment,
DefaultAzureCredentialType::AppService,
]
Expand All @@ -391,12 +431,58 @@ mod tests {
fn test_exclude_managed_identity_credential() {
let mut builder = DefaultAzureCredentialBuilder::new();
builder.exclude_managed_identity_credential();
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Specific,
DefaultAzureCredentialType::Environment,
DefaultAzureCredentialType::AzureCli,
]
);
}

/// test excluding specific credential
#[test]
fn test_exclude_specific_credential() {
let mut builder = DefaultAzureCredentialBuilder::new();
builder.exclude_specific_credential();
assert_eq!(
builder.included(),
vec![
DefaultAzureCredentialType::Environment,
DefaultAzureCredentialType::AppService,
DefaultAzureCredentialType::AzureCli,
]
);
}

// test specific environment credential
#[test]
fn test_specific_environment_credential() {
let env = EnvEnum::Mem(MemEnv::from(
&[
("AZURE_CREDENTIAL_TYPE", "environment"),
("AZURE_TENANT_ID", "1"),
("AZURE_CLIENT_ID", "2"),
("AZURE_CLIENT_SECRET", "3"),
][..],
));
let http_client = azure_core::new_noop_client();
let options = TokenCredentialOptions::new(
env,
http_client,
azure_core::authority_hosts::AZURE_PUBLIC_CLOUD.to_owned(),
);
let credential = DefaultAzureCredentialBuilder::new()
.with_options(options)
.build();
assert_eq!(credential.sources.len(), 1);
match &credential.sources[0] {
DefaultAzureCredentialEnum::Specific(credential) => match credential.source() {
SpecificAzureCredentialEnum::Environment(_credential) => {}
_ => panic!("expected environment credential"),
},
_ => panic!("expected specific credential"),
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ pub struct EnvironmentCredential {
}

impl EnvironmentCredential {
pub fn create(options: TokenCredentialOptions) -> azure_core::Result<EnvironmentCredential> {
pub fn create(
options: impl Into<TokenCredentialOptions>,
) -> azure_core::Result<EnvironmentCredential> {
let options = options.into();
let env = options.env();
let tenant_id = env
.var(AZURE_TENANT_ID_ENV_KEY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ pub struct ImdsManagedIdentityCredential {

impl ImdsManagedIdentityCredential {
pub fn new(
options: TokenCredentialOptions,
options: impl Into<TokenCredentialOptions>,
endpoint: Url,
api_version: &str,
secret_header: HeaderName,
secret_env: &str,
id: ImdsId,
) -> Self {
let options = options.into();
Self {
http_client: options.http_client(),
endpoint,
Expand Down
16 changes: 10 additions & 6 deletions sdk/identity/src/token_credentials/specific_azure_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,14 @@ impl TokenCredential for SpecificAzureCredentialEnum {

#[derive(Debug)]
pub struct SpecificAzureCredential {
credential: SpecificAzureCredentialEnum,
source: SpecificAzureCredentialEnum,
}

impl SpecificAzureCredential {
pub fn create(options: TokenCredentialOptions) -> azure_core::Result<SpecificAzureCredential> {
let env = options.env();
let credential_type = env.var(AZURE_CREDENTIAL_TYPE)?;
let credential: SpecificAzureCredentialEnum = match credential_type.to_lowercase().as_str()
{
let source: SpecificAzureCredentialEnum = match credential_type.to_lowercase().as_str() {
azure_credential_types::ENVIRONMENT => EnvironmentCredential::create(options)
.map(SpecificAzureCredentialEnum::Environment)?,
azure_credential_types::APP_SERVICE => {
Expand All @@ -88,18 +87,23 @@ impl SpecificAzureCredential {
}))
}
};
Ok(Self { credential })
Ok(Self { source })
}

#[cfg(test)]
pub(crate) fn source(&self) -> &SpecificAzureCredentialEnum {
&self.source
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for SpecificAzureCredential {
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
self.credential.get_token(scopes).await
self.source.get_token(scopes).await
}

async fn clear_cache(&self) -> azure_core::Result<()> {
self.credential.clear_cache().await
self.source.clear_cache().await
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct VirtualMachineManagedIdentityCredential {
}

impl VirtualMachineManagedIdentityCredential {
pub fn new(options: TokenCredentialOptions) -> Self {
pub fn new(options: impl Into<TokenCredentialOptions>) -> Self {
let endpoint = Url::parse(ENDPOINT).unwrap(); // valid url constant
Self {
credential: ImdsManagedIdentityCredential::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ pub struct WorkloadIdentityCredential {
impl WorkloadIdentityCredential {
/// Create a new `WorkloadIdentityCredential`
pub fn new<T>(
options: TokenCredentialOptions,
options: impl Into<TokenCredentialOptions>,
tenant_id: String,
client_id: String,
token: T,
) -> Self
where
T: Into<Secret>,
{
let options = options.into();
Self {
http_client: options.http_client().clone(),
authority_host: options.authority_host().clone(),
Expand Down

0 comments on commit 96a78ed

Please sign in to comment.