From 1b4f3d85f9c9fc228526bb0017a59302a3bc3a7c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 14 Mar 2023 14:09:40 -0500 Subject: [PATCH] [azure]: Added extra_vm_options --- dask_cloudprovider/azure/azurevm.py | 18 ++++++++++++++- dask_cloudprovider/cloudprovider.yaml | 1 + doc/source/azure.rst | 33 +++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/dask_cloudprovider/azure/azurevm.py b/dask_cloudprovider/azure/azurevm.py index a2accbb0..521e7901 100644 --- a/dask_cloudprovider/azure/azurevm.py +++ b/dask_cloudprovider/azure/azurevm.py @@ -48,6 +48,7 @@ def __init__( extra_bootstrap=None, auto_shutdown: bool = None, marketplace_plan: dict = {}, + extra_vm_options: Optional[dict] = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -71,6 +72,7 @@ def __init__( self.auto_shutdown = auto_shutdown self.env_vars = env_vars self.marketplace_plan = marketplace_plan + self.extra_vm_options = extra_vm_options or {} async def create_vm(self): [subnet_info, *_] = await self.cluster.call_async( @@ -179,6 +181,13 @@ async def create_vm(self): vm_parameters["storage_profile"]["image_reference"]["version"] = "latest" self.cluster._log("Using Marketplace VM image with a Plan") + repeated = self.extra_vm_options.keys() & vm_parameters.keys() + if repeated: + raise TypeError( + f"Parameters are passed in both 'extra_vm_options' and as regular parameters: {repeated}" + ) + vm_parameters = {**self.extra_vm_options, **vm_parameters} + self.cluster._log("Creating VM") if self.cluster.debug: self.cluster._log( @@ -344,6 +353,9 @@ class AzureVMCluster(VMCluster): The ID of the Azure Subscription to create the virtual machines in. If not specified, then dask-cloudprovider will attempt to use the configured default for the Azure CLI. List your subscriptions with ``az account list``. 
+ extra_vm_options: dict[str, Any] + Additional arguments to provide to Azure's ``VirtualMachinesOperations.begin_create_or_update`` + when creating the scheduler and worker VMs. Examples -------- @@ -472,6 +484,7 @@ def __init__( debug: bool = False, marketplace_plan: dict = {}, subscription_id: Optional[str] = None, + extra_vm_options: Optional[dict] = None, **kwargs, ): self.config = ClusterConfig(dask.config.get("cloudprovider.azure", {})) @@ -550,7 +563,9 @@ def __init__( """To create a virtual machine from Marketplace image or a custom image sourced from a Marketplace image with a plan, all 3 fields 'name', 'publisher' and 'product' must be passed.""" ) - + self.extra_vm_options = extra_vm_options or self.config.get( + "azurevm.extra_vm_options" + ) self.options = { "cluster": self, "config": self.config, @@ -563,6 +578,7 @@ def __init__( "auto_shutdown": self.auto_shutdown, "docker_image": self.docker_image, "marketplace_plan": self.marketplace_plan, + "extra_vm_options": self.extra_vm_options, } self.scheduler_options = { "vm_size": self.scheduler_vm_size, diff --git a/dask_cloudprovider/cloudprovider.yaml b/dask_cloudprovider/cloudprovider.yaml index 3e7c5ce7..f9d3ae63 100755 --- a/dask_cloudprovider/cloudprovider.yaml +++ b/dask_cloudprovider/cloudprovider.yaml @@ -84,6 +84,7 @@ cloudprovider: # name: "ngc-base-version-21-02-2" # publisher: "nvidia" # product: "ngc_azure_17_11" + extra_options: {} # Additional options to provide when creating the VMs. digitalocean: token: null # API token for interacting with the Digital Ocean API diff --git a/doc/source/azure.rst b/doc/source/azure.rst index f633292c..a1334a3d 100644 --- a/doc/source/azure.rst +++ b/doc/source/azure.rst @@ -93,6 +93,39 @@ or specific IP. Again take note of this security group name for later. +Extra options +^^^^^^^^^^^^^ + +To further customize the VMs created, you can provide ``extra_vm_options`` to :class:`AzureVMCluster`. 
For example, to set the identity +of the virtual machines to a (previously created) user-assigned identity, create an ``azure.mgmt.compute.models.VirtualMachineIdentity``: + +.. code-block:: python + + >>> import os + >>> import azure.identity + >>> import dask_cloudprovider.azure + >>> import azure.mgmt.compute.models + + >>> subscription_id = os.environ["DASK_CLOUDPROVIDER__AZURE__SUBSCRIPTION_ID"] + >>> rg_name = os.environ["DASK_CLOUDPROVIDER__AZURE__RESOURCE_GROUP"] + >>> identity_name = "dask-cloudprovider-identity" + >>> v = azure.mgmt.compute.models.UserAssignedIdentitiesValue() + >>> user_assigned_identities = { + ... f"/subscriptions/{subscription_id}/resourcegroups/{rg_name}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{identity_name}": v + ... } + >>> identity = azure.mgmt.compute.models.VirtualMachineIdentity( + ... type="UserAssigned", + ... user_assigned_identities=user_assigned_identities + ... ) + + +And then provide that to :class:`AzureVMCluster`: + +.. code-block:: python + + >>> cluster = dask_cloudprovider.azure.AzureVMCluster(extra_vm_options={"identity": identity.as_dict()}) + >>> cluster.scale(1) + Dask Configuration ^^^^^^^^^^^^^^^^^^