pytorch cuda setup for nvidia gh200 (arm64)
| problem: | default pip install torch on arm64 gives a cpu-only build, even with an nvidia gh200 gpu |
| symptom: | torch.cuda.is_available() returns False, version shows 2.9.0+cpu |
| cause: | PyPI only distributes CPU-only wheels for ARM64; CUDA wheels on separate index |
| solution: | configure uv to use pytorch's cuda index with triton dependency override |
| key config: | add a [[tool.uv.index]] entry for pytorch's cuda index and a [tool.uv.sources] mapping for torch in pyproject.toml |
| first official support: | pytorch 2.7 (late 2024) - arm64 cuda wheels still relatively new |
the problem
what goes wrong
when you install pytorch on an arm64/aarch64 system using standard methods:
pip install torch
# or
uv pip install torch

you get a cpu-only version, even though:
- your system has an nvidia gh200 gpu (96gb or 144gb hbm3/hbm3e)
- cuda 12.8 is installed (nvcc --version works)
- nvidia drivers are working (nvidia-smi shows the gpu)
how to detect the issue
# check cuda availability
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
# ❌ CUDA available: False
# check pytorch version
python -c "import torch; print(torch.__version__)"
# ❌ 2.9.0+cpu (notice the "+cpu" suffix)

the correct version should show 2.9.0+cu128 or 2.10.0.dev20251109+cu128.
why this happens
pytorch distributes different wheel packages for different platforms:
- pypi (default): hosts only cpu-only wheels for arm64
- cuda-specific indices: host gpu-enabled wheels for arm64
reasons for this split:
- arm64 gpu support is recent (pytorch 2.7+, late 2024)
- cuda wheels are large (500mb+ with bundled cuda libraries)
- arm64 cuda users are still a small minority
the solution
quick install with uv
add this configuration to your pyproject.toml:
[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true
[tool.uv]
override-dependencies = [
"pytorch-triton ; sys_platform == 'win32'", # make triton optional on arm64
]
[tool.uv.sources]
torch = { index = "pytorch-nightly-cu128" }

then sync:
uv sync --reinstall-package torch --prerelease=allow

stable pytorch 2.9
for production use, switch to stable 2.9:
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv]
override-dependencies = [
"pytorch-triton ; sys_platform == 'win32'",
]
[tool.uv.sources]
torch = { index = "pytorch-cu128" }

install:
uv sync --reinstall-package torch

verification
test your installation:
# test 1: check cuda availability
uv run python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
# ✅ CUDA available: True
# test 2: check version and device
uv run python -c "
import torch
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA version: {torch.version.cuda}')
print(f'Device count: {torch.cuda.device_count()}')
print(f'Device 0: {torch.cuda.get_device_name(0)}')
"
# ✅ PyTorch version: 2.10.0.dev20251109+cu128
# ✅ CUDA version: 12.8
# ✅ Device count: 1
# ✅ Device 0: NVIDIA GH200 480GB
# test 3: run gpu computation
uv run python -c "
import torch
x = torch.randn(1000, 1000).cuda()
y = x @ x.T
print(f'✓ GPU tensor operation successful')
print(f'Result shape: {y.shape}, device: {y.device}')
"
# ✅ Result shape: torch.Size([1000, 1000]), device: cuda:0

understanding the configuration
custom index definition
[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true

components:
- name: arbitrary identifier for the index
- url: pytorch's nightly cuda 12.8 wheel repository
- explicit = true: only use this index for explicitly configured packages (not transitive dependencies like numpy)
available indices:
| version | url | description |
|---|---|---|
| nightly cu128 | https://download.pytorch.org/whl/nightly/cu128 | latest features (2.10.0.dev) |
| stable 2.9 | https://download.pytorch.org/whl/cu128 | pytorch 2.9.0+cu128 |
| stable 2.8 | https://download.pytorch.org/whl/cu126 | pytorch 2.8.0+cu126 (cuda 12.6) |
triton dependency override
[tool.uv]
override-dependencies = [
"pytorch-triton ; sys_platform == 'win32'",
]

what this does:
tells uv to only require pytorch-triton on windows, effectively making it optional on arm64 linux.
why needed:
- triton only provides x86_64 wheels (no arm64 wheels available)
- pytorch depends on triton for torch.compile() optimizations
- update (2025): triton 3.5.0 now has arm64 wheels on pypi, but integration with pytorch cuda on gh200 may still require this override
impact:
- ✅ pytorch cuda operations work fine
- ✅ training and inference work
- ❌ torch.compile() may not be available (depending on triton version)
- ❌ some advanced optimizations disabled
for most training/inference workloads, this is acceptable.
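if you want to know whether torch.compile() actually works in a given environment, a small runtime check can tell you at startup. this is a sketch, not an official api: the maybe_compile helper name is made up here, and it assumes the cuda-enabled build from this guide is installed.

import torch

def maybe_compile(model):
    # return a compiled model if torch.compile/triton works here, otherwise fall back to eager mode
    try:
        compiled = torch.compile(model)
        compiled(torch.randn(1, 16, device='cuda'))  # compilation is lazy; trigger it with a dummy forward pass
        return compiled
    except Exception as exc:  # e.g. triton missing or unsupported on this platform
        print(f'torch.compile unavailable, using eager mode: {exc}')
        return model

model = maybe_compile(torch.nn.Linear(16, 16).cuda())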
package source mapping
[tool.uv.sources]
torch = { index = "pytorch-nightly-cu128" }

directs uv to fetch torch from the custom cuda index instead of pypi.
nvidia gh200 architecture
overview
the nvidia gh200 grace hopper superchip combines:
- grace cpu: up to 72 arm neoverse v2 cores (armv9.0-a isa)
- hopper gpu: h100-class gpu with up to 144gb hbm3e memory
- nvlink-c2c: 900 gb/s coherent interconnect between cpu and gpu
specifications
cpu (grace):
- 72 arm neoverse v2 cores
- 4×128-bit simd units per core
- arm scalable vector extensions 2 (sve2)
- up to 480-512 gb lpddr5x memory (546 gb/s bandwidth)
- 117 mb l3 cache
gpu (hopper h100/h200):
- 96gb hbm3 (h100) or 144gb hbm3e (h200)
- 4th generation tensor cores
- 3-4tb/s memory bandwidth
- fp8 precision support for ai workloads
memory architecture:
- cpu: up to 480gb lpddr5x
- gpu: up to 144gb hbm3e
- combined: up to 624gb unified memory space
- nvl2 config (dual superchip): 288gb hbm + 1.2tb total memory
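as a sanity check against the figures above, you can ask pytorch directly how much hbm the gpu exposes (this assumes the cuda-enabled build from this guide is already installed):

import torch

props = torch.cuda.get_device_properties(0)
print(f'GPU: {props.name}')
print(f'HBM capacity: {props.total_memory / 1024**3:.1f} GB')
print(f'Compute capability: {props.major}.{props.minor}')  # hopper reports 9.0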
why arm64 matters
unlike traditional x86_64 servers with nvidia gpus, gh200 uses arm64 architecture:
- different instruction set (aarch64 vs x86_64)
- different binary compatibility requirements
- different wheel packages needed from pytorch
- historically limited software ecosystem support
arm64 advantages for ai:
- better power efficiency
- unified memory architecture with gpu via nvlink-c2c
- lower latency cpu-gpu communication
- native support for modern arm features (sve2)
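to get a rough feel for the cpu-gpu link bandwidth on your own system, here is a deliberately unscientific timing sketch using a pinned host buffer; the transfer size is an arbitrary choice and the result is only indicative:

import torch

size_mb = 1024  # arbitrary transfer size
x = torch.empty(size_mb * 1024 * 1024, dtype=torch.uint8, pin_memory=True)

x.to('cuda', non_blocking=True)  # warm-up copy so driver/allocator setup is excluded
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
x.to('cuda', non_blocking=True)
end.record()
torch.cuda.synchronize()

elapsed_s = start.elapsed_time(end) / 1000  # elapsed_time() returns milliseconds
print(f'host-to-device bandwidth: ~{size_mb / 1024 / elapsed_s:.1f} GB/s')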
pytorch arm64 cuda support history
timeline
pre-2024: no official support
- pytorch 1.x - 2.6: arm64 wheels were cpu-only
- workarounds required:
- build pytorch from source with cuda
- use nvidia ngc containers
- use community wheels (e.g., kumatee/pytorch-aarch64)
pytorch 2.7 (late 2024): breakthrough release 🎉
- first official arm64 cuda wheels released
- cuda 12.8 support for both x86_64 and aarch64
- hosted on pytorch’s cuda wheel indices
- support for nvidia blackwell gpu architecture
pytorch 2.8-2.9 (2025): maturation
- continued arm64 cuda support
- improved wheel distribution system
- better arm64 optimizations
pytorch 2.10 (current nightly)
- ongoing arm64 cuda improvements
- latest optimizations for grace hopper systems
- available in nightly builds
key insight
arm64 cuda support in pytorch is very recent, arriving only with the pytorch 2.7 release. this explains:
- why it’s not well-documented
- why default installations don’t work
- why special configuration is required
- why many developers don’t know about this issue
installation guide
prerequisites
-
nvidia gh200 system with:
- cuda 12.8 installed (
nvcc --version) - nvidia drivers working (
nvidia-smi) - python 3.10+ (python 3.13 recommended)
- cuda 12.8 installed (
-
uv package manager installed:
curl -LsSf https://astral.sh/uv/install.sh | sh
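the following small script (a sketch, not an official tool) confirms the architecture, python version, and required tooling in one go before you edit pyproject.toml:

import platform
import shutil
import sys

assert platform.machine() == 'aarch64', 'expected an arm64 (aarch64) system'
assert (3, 10) <= sys.version_info[:2] <= (3, 13), 'expected python 3.10-3.13'
for tool in ('nvcc', 'nvidia-smi', 'uv'):
    assert shutil.which(tool), f'{tool} not found on PATH'
print('preflight checks passed')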
step-by-step installation
step 1: create/update pyproject.toml
add the configuration from the solution section above.
step 2: sync dependencies
for existing projects with torch already installed:
uv sync --reinstall-package torch --prerelease=allow

for new projects:

uv sync

the --prerelease=allow flag is needed for nightly builds.
step 3: verify installation
run the verification commands from the verification section above.

note that the first sync downloads ~3gb of cuda libraries:
- nvidia-cublas-cu12 (563 mb)
- nvidia-cudnn-cu12 (672 mb)
- nvidia-nccl-cu12 (308 mb)
- nvidia-cusparse-cu12 (275 mb)
- nvidia-cusolver-cu12 (255 mb)
- nvidia-cufft-cu12 (184 mb)
- nvidia-nvshmem-cu12 (133 mb)
- torch itself (535 mb)
expected output:
Installed 17 packages in Xms
+ nvidia-cublas-cu12==12.8.4.1
+ nvidia-cuda-cupti-cu12==12.8.90
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ nvidia-cuda-runtime-cu12==12.8.90
+ nvidia-cudnn-cu12==9.10.2.21
+ nvidia-cufft-cu12==11.3.3.83
+ nvidia-cufile-cu12==1.13.1.3
+ nvidia-curand-cu12==10.3.9.90
+ nvidia-cusolver-cu12==11.7.3.90
+ nvidia-cusparse-cu12==12.5.8.93
+ nvidia-cusparselt-cu12==0.7.1
+ nvidia-nccl-cu12==2.27.5
+ nvidia-nvjitlink-cu12==12.8.93
+ nvidia-nvshmem-cu12==3.4.5
+ nvidia-nvtx-cu12==12.8.90
+ torch==2.10.0.dev20251109+cu128

troubleshooting
still getting cpu-only pytorch
symptom:
$ python -c "import torch; print(torch.__version__)"
2.9.0+cpu # still has +cpu suffix

solutions:
- verify the index configuration in pyproject.toml
- force reinstall:
  uv pip uninstall torch
  uv sync --reinstall-package torch
- check the lock file: delete uv.lock and re-sync:
  rm uv.lock
  uv sync
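as an extra guard, a small runtime check (a sketch you might drop at the top of a training script) can fail fast if the cpu-only wheel ever sneaks back in:

import torch

if '+cpu' in torch.__version__ or not torch.cuda.is_available():
    raise RuntimeError(
        f'cpu-only pytorch detected ({torch.__version__}); '
        're-check the [tool.uv.sources] index configuration and re-sync'
    )
print(f'cuda build confirmed: {torch.__version__}')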
dependency resolution failures
symptom:
error: Distribution `torch==...` can't be installed because it doesn't
have a source distribution or wheel for the current platform

solutions:
- add the prerelease flag if using nightly:
  uv sync --prerelease=allow
- check the python version (must be 3.10-3.13):
  python --version
- verify the architecture:
  uname -m
  # should show: aarch64
pytorch-triton errors
symptom:
Because there is no version of pytorch-triton==3.5.0+git...
and torch==... depends on pytorch-triton...

solution:
ensure the override is correct in pyproject.toml:
[tool.uv]
override-dependencies = [
"pytorch-triton ; sys_platform == 'win32'", # note: pytorch-triton, not triton
]

import errors after installation
symptom:
ImportError: libcuda.so.1: cannot open shared object file

solutions:
- verify cuda is in the library path:
  export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
- check the cuda installation:
  nvcc --version
  ldconfig -p | grep cuda
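to test whether the driver library is loadable at all, independent of pytorch, a quick ctypes check (a sketch) narrows the problem down to the driver/library path rather than the wheel:

import ctypes

try:
    ctypes.CDLL('libcuda.so.1')
    print('libcuda.so.1 loaded successfully')
except OSError as exc:
    print(f'libcuda.so.1 not loadable: {exc}')
    print('check LD_LIBRARY_PATH and the nvidia driver installation')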
out of memory errors
symptom:
RuntimeError: CUDA out of memory. Tried to allocate X MiB

solutions:
- check available memory:
  nvidia-smi --query-gpu=memory.free,memory.total --format=csv
- reduce the batch size in the training configuration
- enable gradient checkpointing to trade compute for memory
- use mixed precision (bf16/fp16) to reduce memory usage
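two of these mitigations are illustrated in the sketch below (the layer and batch sizes are arbitrary placeholders): a smaller batch plus bf16 autocast to shrink activation memory, followed by pytorch's own view of free hbm.

import torch

model = torch.nn.Linear(4096, 4096).cuda()
x = torch.randn(8, 4096, device='cuda')  # smaller batch than the failing run

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = model(x)  # activations computed in bf16

free, total = torch.cuda.mem_get_info()
print(f'hbm free: {free / 1024**3:.1f} GB of {total / 1024**3:.1f} GB')
print(f'peak allocated by pytorch: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB')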
performance considerations
expected performance
on gh200 with properly configured cuda pytorch:
- 2-3x faster training vs cpu-only
- full utilization of 96gb/144gb hbm3e gpu memory
- efficient cpu-gpu communication via nvlink-c2c (900 gb/s)
- native fp8 support for transformer models
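to sanity-check speedups on your own hardware, here is a quick throwaway benchmark; the matrix size and methodology are arbitrary choices, so treat the numbers as indicative only:

import time
import torch

n = 4096
a_cpu = torch.randn(n, n)
t0 = time.perf_counter()
a_cpu @ a_cpu
cpu_s = time.perf_counter() - t0

a_gpu = a_cpu.cuda()
a_gpu @ a_gpu  # warm-up so cuda context setup is excluded
torch.cuda.synchronize()
t0 = time.perf_counter()
a_gpu @ a_gpu
torch.cuda.synchronize()
gpu_s = time.perf_counter() - t0

print(f'cpu: {cpu_s:.3f}s  gpu: {gpu_s:.3f}s  speedup: {cpu_s / gpu_s:.1f}x')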
nightly vs stable
nightly builds (recommended for gh200):
- ✅ latest arm64 optimizations
- ✅ most recent cuda features
- ✅ active development for grace hopper
- ❌ may have occasional bugs
- ❌ less tested than stable releases
stable builds:
- ✅ more thoroughly tested
- ✅ better for production deployments
- ❌ older arm64 optimizations
- ❌ may lack newest features
updating pytorch
to update to newer nightly builds:
uv sync --upgrade-package torch

to lock to a specific version, modify pyproject.toml:

dependencies = [
    "torch==2.10.0.dev20251109+cu128",  # specific nightly version
]

alternative approaches
nvidia ngc containers
for users who prefer containers over direct installation:
docker run --gpus all -it nvcr.io/nvidia/pytorch:24.02-py3

nvidia provides pre-configured pytorch containers with arm64 support. this is the officially recommended approach for gh200 systems and may be more stable than manual wheel installation.
advantages:
- pre-configured environment
- tested by nvidia
- includes optimized libraries
- no dependency resolution issues
disadvantages:
- larger download size
- less flexible than native installation
- container overhead
building from source
if you need absolute latest features or custom optimizations:
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
export USE_CUDA=1
python setup.py install

note: building from source is time-consuming (1-2 hours on gh200) and requires careful dependency management. only recommended for advanced users or when prebuilt wheels are insufficient.
key takeaways
- problem: default pip install torch on arm64 gives a cpu-only version, even with an nvidia gh200 gpu
- cause: pypi only distributes cpu-only wheels for arm64; cuda wheels live on a separate pytorch index
- solution: configure uv with a custom pytorch cuda index + triton dependency override
- timeline: official arm64 cuda support started with pytorch 2.7 (late 2024) - still relatively new
- triton issue: pytorch-triton historically lacked arm64 wheels, requiring a dependency override workaround
- gh200 specs: 72-core arm cpu + h100/h200 gpu with 900 gb/s nvlink-c2c interconnect
- verification: check that the version shows a +cu128 suffix and torch.cuda.is_available() returns True
- download size: the first install requires ~3gb of cuda libraries
- nightly vs stable: nightly (2.10.0.dev) has the latest arm64 optimizations; stable (2.9.0) is more tested
- alternative: nvidia ngc containers provide a pre-configured pytorch environment for gh200
- performance: expect 2-3x speedup vs cpu-only with full utilization of 96-144gb gpu memory
references
official pytorch documentation
- pytorch installation guide
- pytorch 2.7 release blog - first official arm64 cuda support
- pytorch cuda wheel index
- pytorch github discussions - gh200 support
nvidia gh200 resources
- nvidia gh200 official page
- gh200 architecture whitepaper
- nvidia developer blog - gh200
- nvidia ngc pytorch container
uv package manager
community resources
- pytorch github issue #123835 - package manager install on gh200
- pytorch github issue #130558 - triton arm64 support request
- pytorch github pr #144049 - add cuda aarch64 triton wheel build
- kumatee/pytorch-aarch64 - community arm64 wheels (pre-2.7)
related wiki pages
- pytorch setup with uv - general pytorch configuration
- pytorch 2.9 release notes - latest stable release
- uv package manager - comprehensive uv guide