commit df664e0fa9d360f20051638f74fbaceb0814b144
parent 9030cecdf06da99290340ca429d2c1f94b96f9f4
Author: Phillip Self <[email protected]>
Date: Sun, 12 Mar 2023 21:27:27 -0500
Add Apple Metal support for faster training on M1/M2 (#122)
Co-authored-by: Phillip Self <[email protected]>
Diffstat:
1 file changed, 2 insertions(+), 0 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -234,6 +234,8 @@ def _get_configs(
}
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
+ elif torch.backends.mps.is_available():
+ device_config = {"accelerator": "mps", "devices": 1}
else:
print("WARNING: No GPU was found. Training will be very slow!")
device_config = {}