Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SSO/Spot] Fix spot controller failed to launch spot cluster when using SSO #1817

Merged
merged 5 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2469,7 +2469,11 @@ def validate_schema(obj, schema, err_msg_prefix=''):


def check_public_cloud_enabled():
"""Checks if any of the public clouds is enabled."""
"""Checks if any of the public clouds is enabled.

Raises:
exceptions.NoCloudAccessError: if no public cloud is enabled.
"""

def _no_public_cloud():
enabled_clouds = global_user_state.get_enabled_clouds()
Expand All @@ -2483,7 +2487,7 @@ def _no_public_cloud():
sky_check.check(quiet=True)
if _no_public_cloud():
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
raise exceptions.NoCloudAccessError(
'Cloud access is not set up. Run: '
f'{colorama.Style.BRIGHT}sky check{colorama.Style.RESET_ALL}')

Expand Down
2 changes: 1 addition & 1 deletion sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def _launch_with_confirm(
# Show the optimize log before the prompt if the cluster does not exist.
try:
backend_utils.check_public_cloud_enabled()
except RuntimeError as e:
except exceptions.NoCloudAccessError as e:
# Catch the exception where the public cloud is not enabled, and
# only print the error message without the error type.
click.secho(e, fg='yellow')
Expand Down
63 changes: 50 additions & 13 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Amazon Web Services."""
import enum
import functools
import json
import os
Expand Down Expand Up @@ -40,6 +41,19 @@
DEFAULT_AMI_GB = 45


class AWSIdentityType(enum.Enum):
    """AWS identity type.

    The identity type is determined by the current user identity,
    based on `aws configure list`. For SSO and IAM_ROLE, the member's
    value is the string looked for in the output of that command
    (its "Type" column); STATIC is the fallback when neither is found.
    """
    # `aws configure list` reports the credentials as coming from AWS SSO.
    SSO = 'sso'
    # `aws configure list` reports an IAM role, e.g. when the user is on a
    # VM (or spot-controller) that was created by an SSO account.
    IAM_ROLE = 'iam-role'
    # Neither of the above was reported; treated as static credentials
    # (e.g. ~/.aws/credentials).
    STATIC = 'static'


@clouds.CLOUD_REGISTRY.register
class AWS(clouds.Cloud):
"""Amazon Web Services."""
Expand Down Expand Up @@ -405,19 +419,27 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
static_credential_exists = os.path.isfile(
os.path.expanduser('~/.aws/credentials'))
hints = None
if cls._is_current_identity_sso():
hints = 'AWS SSO is set. '
identity_type = cls._current_identity_type()
single_cloud_hint = (
' It will work if you use AWS only, but will cause problems '
'if you want to use multiple clouds. To set up static credentials, '
'try: aws configure')
if identity_type == AWSIdentityType.SSO:
hints = 'AWS SSO is set.'
if static_credential_exists:
hints += (
' To ensure multiple clouds work correctly, please use SkyPilot '
'with static credentials (e.g., ~/.aws/credentials) by unsetting '
'the AWS_PROFILE environment variable.')
else:
hints += (
' It will work if you use AWS only, but will cause problems '
'if you want to use multiple clouds. To set up static credentials, '
'try: aws configure')

hints += single_cloud_hint
elif identity_type == AWSIdentityType.IAM_ROLE:
# When using an IAM role, the credentials may not exist in the
# ~/.aws/credentials file. So we don't check for the existence of the
# file. This will happen when the user is on a VM (or spot-controller)
# created by an SSO account, i.e. the VM will be assigned the IAM
# role: skypilot-v1.
hints = f'AWS IAM role is set.{single_cloud_hint}'
else:
# This file is required because it is required by the VMs launched on
# other clouds to access private s3 buckets and resources like EC2.
Expand All @@ -433,15 +455,30 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
return True, hints

@classmethod
def _current_identity_type(cls) -> Optional[AWSIdentityType]:
    """Infers the current AWS identity type from `aws configure list`.

    Returns:
        AWSIdentityType.SSO or AWSIdentityType.IAM_ROLE when that type is
        reported for the access/secret key, AWSIdentityType.STATIC
        otherwise, or None if `aws configure list` exits with a non-zero
        return code.
    """
    proc = subprocess.run('aws configure list',
                          shell=True,
                          check=False,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
    if proc.returncode != 0:
        return None
    # We determine the identity type by looking at the output of
    # `aws configure list`. The output looks like:
    #   Name                   Value         Type    Location
    #   ----                   -----         ----    --------
    #   profile                <not set>     None    None
    #   access_key     *       <not set>     sso     None
    #   secret_key     *       <not set>     sso     None
    #   region                 <not set>     None    None
    # Only the access_key/secret_key rows describe where the credentials
    # come from, so we only look at whole tokens of those rows. A plain
    # substring search over the entire output would also match a profile
    # name, region or file path that merely contains "sso"/"iam-role".
    reported_types = set()
    for line in proc.stdout.decode().splitlines():
        fields = line.split()
        if fields and fields[0] in ('access_key', 'secret_key'):
            reported_types.update(fields[1:])
    if AWSIdentityType.SSO.value in reported_types:
        return AWSIdentityType.SSO
    if AWSIdentityType.IAM_ROLE.value in reported_types:
        return AWSIdentityType.IAM_ROLE
    return AWSIdentityType.STATIC

@classmethod
def get_current_user_identity(cls) -> Optional[List[str]]:
Expand Down Expand Up @@ -544,8 +581,8 @@ def get_current_user_identity(cls) -> Optional[List[str]]:

def get_credential_file_mounts(self) -> Dict[str, str]:
# TODO(skypilot): ~/.aws/credentials is required for users using multiple clouds.
# If this file does not exist, users can launch on AWS via AWS SSO and assign
# IAM role to the cluster.
# If this file does not exist, users can launch on AWS via AWS SSO or assumed IAM
# role (only when the user is on an AWS cluster) and assign IAM role to the cluster.
# However, if users launch clusters in a non-AWS cloud, those clusters do not
# understand AWS IAM role so will not be able to access private AWS EC2 resources
# and S3 buckets.
Expand All @@ -559,7 +596,7 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
# to define a mechanism to find out the cloud provider of the cluster to be
# launched in this function and make sure the cluster will not be used for
# launching clusters in other clouds, e.g. spot controller.
if self._is_current_identity_sso():
if self._current_identity_type() != AWSIdentityType.STATIC:
return {}
return {
f'~/.aws/{filename}': f'~/.aws/{filename}'
Expand Down
5 changes: 5 additions & 0 deletions sky/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,8 @@ class CloudUserIdentityError(Exception):
class ClusterOwnerIdentityMismatchError(Exception):
    """The cluster's owner identity does not match the current user identity.

    Carries no state beyond the standard ``Exception`` arguments.
    """
    pass


class NoCloudAccessError(Exception):
    """Raised when all clouds are disabled.

    Raised by backend_utils.check_public_cloud_enabled() when no public
    cloud is enabled after running sky_check.check(). Carries no state
    beyond the standard ``Exception`` arguments.
    """
    pass
1 change: 1 addition & 0 deletions sky/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def launch(
our pre-checks (e.g., cluster name invalid) or a region/zone
throwing resource unavailability.
exceptions.CommandError: any ssh command error.
exceptions.NoCloudAccessError: if all clouds are disabled.
Other exceptions may be raised depending on the backend.
"""
entrypoint = task
Expand Down
1 change: 1 addition & 0 deletions sky/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def optimize(
Raises:
exceptions.ResourcesUnavailableError: if no resources are available
for a task.
exceptions.NoCloudAccessError: if no public clouds are enabled.
"""
# This function is effectful: mutates every node in 'dag' by setting
# node.best_resources if it is None.
Expand Down
3 changes: 2 additions & 1 deletion sky/spot/recovery_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
detach_run=True,
_is_launched_by_spot_controller=True)
logger.info('Spot cluster launched.')
except exceptions.InvalidClusterNameError as e:
except (exceptions.InvalidClusterNameError,
Michaelvll marked this conversation as resolved.
Show resolved Hide resolved
exceptions.NoCloudAccessError) as e:
logger.error('Failure happened before provisioning. '
f'{common_utils.format_exception(e)}')
if raise_on_failure:
Expand Down