torchrunx 🔥


By Apoorv Khandelwal and Peter Curtin

Automatically distribute PyTorch functions onto multiple machines or GPUs

Installation

pip install torchrunx

Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0

A shared filesystem and SSH access are also required if using multiple machines

Minimal example

Here's a simple example where we distribute train_model to two hosts (with 2 GPUs each):

import os

import torch

def train_model(model, dataset):
    trained_model = train(model, dataset)

    # only rank 0 saves the model and returns its path; other workers return None
    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None
import torchrunx as trx

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]  # return from rank 0 (first worker on "localhost")

Why should I use this?

torchrun is a hammer. torchrunx is a chisel.

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Convenience:

  • If you don't want to set up dist.init_process_group yourself
  • If you want to run python myscript.py instead of torchrun myscript.py
  • If you don't want to manually SSH into every machine and run torchrun --master_addr ... --master_port ... (and then babysit those machines for hanging failures); a sketch of the boilerplate you avoid follows this list
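
For illustration, here's roughly what that boilerplate looks like without torchrunx (a generic torchrun + torch.distributed sketch, not torchrunx code):

# On every machine, you would run something like:
#   torchrun --nnodes 2 --nproc-per-node 2 --master_addr <head-node> --master_port <port> myscript.py
# and inside myscript.py you would set up (and tear down) the process group yourself:

import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # reads RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT set by torchrun
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# ... the actual work ...

dist.destroy_process_group()

torchrunx takes care of all of this, so your script stays plain Python.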

Robustness:

  • If you want to run a complex, modular workflow in one script
    • no worries about memory leaks or OS failures
    • don't parallelize your entire script: just the functions you want

Features:

  • Our launch utility is super Pythonic
  • If you want to run distributed PyTorch functions from Python notebooks
  • Automatic integration with SLURM (see the sketch after this list)
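
To make the SLURM point concrete: when your script runs inside a SLURM allocation, torchrunx can determine hosts and workers for you. Doing the same by hand would look roughly like this (generic SLURM code shown purely for illustration, reusing names from the example above; not torchrunx internals):

import os
import subprocess

import torchrunx as trx

# Expand the allocated node list, e.g. "node[01-02]" -> ["node01", "node02"]
hostnames = subprocess.check_output(
    ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]],
    text=True,
).splitlines()

results = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=hostnames,
    # assuming one worker per allocated GPU on each node
    workers_per_host=int(os.environ.get("SLURM_GPUS_ON_NODE", "1")),
)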

Why not?

  • We don't support fault tolerance via torch elastic. That's probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.

More complicated example

We could also launch multiple functions (e.g. training, then testing), each with its own hosts and GPUs:

def train_model(model, dataset):
    trained_model = train(model, dataset)

    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None

def test_model(model_path, test_dataset):
    model = torch.load(model_path)
    accuracy = inference(model, test_dataset)
    return accuracy
import torchrunx as trx

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]  # return from rank 0 (first worker on "localhost")



accuracy = trx.launch(
    func=test_model,
    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
    hostnames=["localhost"],
    workers_per_host=1
)["localhost"][0]

print(f'Accuracy: {accuracy}')