By Apoorv Khandelwal and Peter Curtin
Automatically distribute PyTorch functions onto multiple machines or GPUs
```bash
pip install torchrunx
```
Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0. If using multiple machines, you also need a shared filesystem and SSH access to each machine.
Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
```python
import os

import torch
import torchrunx as trx


def train_model(model, dataset):
    trained_model = train(model, dataset)  # `train` is your own training loop

    # Only the rank-0 worker saves the model and reports where it is
    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'
    return None


# `my_model` and `mnist_train` stand in for your own model and dataset
model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]  # return from rank 0 (first worker on "localhost")
```
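To see why `["localhost"][0]` picks out the rank-0 result, here is a minimal sketch of the usual way global ranks are laid out: contiguously, host by host, in the order of `hostnames` (global rank = host index × workers per host + local rank). This is the convention `torchrun` follows for equally sized nodes; we are assuming, not confirming, that torchrunx orders ranks the same way. `rank_layout` is a hypothetical helper, not part of torchrunx.

```python
from __future__ import annotations


def rank_layout(hostnames: list[str], workers_per_host: int) -> dict[str, list[int]]:
    """Map each host to the global ranks of its workers (hypothetical helper).

    Assumes ranks are assigned contiguously, host by host, in list order:
    global_rank = host_index * workers_per_host + local_rank.
    """
    return {
        host: [i * workers_per_host + local_rank for local_rank in range(workers_per_host)]
        for i, host in enumerate(hostnames)
    }


layout = rank_layout(["localhost", "other_node"], workers_per_host=2)
print(layout)  # {'localhost': [0, 1], 'other_node': [2, 3]}
print(layout["localhost"][0])  # 0, i.e. the worker that returns 'model.pt'
```

Under this assumption, indexing the first worker of the first host always yields rank 0, which is the only worker in `train_model` that returns a non-`None` value.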
`torchrun` is a hammer. `torchrunx` is a chisel.
Whether you have 1 GPU, 8 GPUs, or 8 machines:
Convenience:
- If you don't want to set up `dist.init_process_group` yourself
- If you want to run `python myscript.py` instead of `torchrun myscript.py`
- If you don't want to manually SSH and run `torchrun --master-addr --master-port ...` on every machine (and babysit those machines for hanging failures)
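For reference, here is roughly what the manual setup involves. With the default `init_method="env://"`, `torch.distributed.init_process_group` reads `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE` and `RANK` from the environment, and `torchrun` additionally sets `LOCAL_RANK`. The stdlib-only sketch below builds that environment for a single worker. `worker_env` is a hypothetical helper, the address `10.0.0.1` is a placeholder for the first host, and the contiguous rank formula is an assumption; it does not reproduce everything `torchrun` or torchrunx sets.

```python
from __future__ import annotations


def worker_env(
    master_addr: str,
    master_port: int,
    num_hosts: int,
    workers_per_host: int,
    host_index: int,
    local_rank: int,
) -> dict[str, str]:
    """Environment one worker needs for init_process_group(init_method="env://").

    Hypothetical helper: assumes every host runs the same number of workers
    and that ranks are numbered contiguously, host by host.
    """
    if not 0 <= host_index < num_hosts:
        raise ValueError("host_index out of range")
    if not 0 <= local_rank < workers_per_host:
        raise ValueError("local_rank out of range")
    return {
        "MASTER_ADDR": master_addr,  # host where rank 0 runs the rendezvous store
        "MASTER_PORT": str(master_port),
        "WORLD_SIZE": str(num_hosts * workers_per_host),
        "RANK": str(host_index * workers_per_host + local_rank),
        "LOCAL_RANK": str(local_rank),  # typically used to pick this worker's GPU
    }


# Second worker on the second of two 2-GPU hosts
env = worker_env("10.0.0.1", 29500, num_hosts=2, workers_per_host=2, host_index=1, local_rank=1)
print(env["RANK"], env["WORLD_SIZE"])  # 3 4
```

Each process would need this environment before calling `dist.init_process_group`, on every machine, which is the bookkeeping listed above as something you may not want to do yourself.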
Robustness:
- If you want to run a complex, modular workflow in one script
  - no worries about memory leaks or OS failures
  - don't parallelize your entire script: just the functions you want
Features:
- Our `launch` utility is super Pythonic
- Run distributed PyTorch functions from Python notebooks
- Automatic integration with SLURM
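On a SLURM allocation, node names arrive in compressed hostlist form, for example `node[01-03,05]` in `SLURM_JOB_NODELIST`. The sketch below expands that form into the plain `hostnames` list that `launch` takes. `expand_nodelist` is a hypothetical helper that handles only one bracket group per name; it is not torchrunx's own SLURM code, and on a real cluster `scontrol show hostnames` performs the full expansion.

```python
from __future__ import annotations

import re


def split_top_level(nodelist: str) -> list[str]:
    """Split on commas that are not inside [...] brackets."""
    parts, depth, current = [], 0, []
    for ch in nodelist:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
    if current:
        parts.append("".join(current))
    return parts


def expand_nodelist(nodelist: str) -> list[str]:
    """Expand e.g. "node[01-03,05],gpu7" into individual hostnames.

    Hypothetical helper: supports one bracket group per name containing
    comma-separated numbers or ranges; zero-padding follows the range start.
    """
    hosts = []
    for part in split_top_level(nodelist):
        match = re.fullmatch(r"([^\[\]]*)\[([^\[\]]+)\]([^\[\]]*)", part)
        if match is None:
            hosts.append(part)  # plain hostname without brackets
            continue
        prefix, body, suffix = match.groups()
        for item in body.split(","):
            start, _, end = item.partition("-")
            width = len(start)
            for n in range(int(start), int(end or start) + 1):
                hosts.append(f"{prefix}{n:0{width}d}{suffix}")
    return hosts


print(expand_nodelist("node[01-03,05],gpu7"))
# ['node01', 'node02', 'node03', 'node05', 'gpu7']
```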
Why not?
- We don't support fault tolerance via TorchElastic. This is probably only useful if you are using 1000+ GPUs. Maybe someone can make a PR.
We can also launch multiple functions in sequence, each on different GPUs:
```python
import os

import torch
import torchrunx as trx


def train_model(model, dataset):
    trained_model = train(model, dataset)  # `train` is your own training loop
    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'
    return None


def test_model(model_path, test_dataset):
    # weights_only=False is needed to unpickle a full nn.Module (default is True since PyTorch 2.6)
    model = torch.load(model_path, weights_only=False)
    accuracy = inference(model, test_dataset)  # `inference` is your own evaluation loop
    return accuracy


# Train on 4 GPUs across two hosts
model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]  # return from rank 0 (first worker on "localhost")

# Then evaluate the saved model on a single GPU on this machine
accuracy = trx.launch(
    func=test_model,
    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
    hostnames=["localhost"],
    workers_per_host=1
)["localhost"][0]

print(f'Accuracy: {accuracy}')
```