pytorch cuda setup for nvidia gh200 (arm64)

problem: pip install torch on arm64 installs the cpu-only version, even with an nvidia gh200 gpu and cuda 12.8

symptom: torch.cuda.is_available() returns False; version shows 2.9.0+cpu
cause: pypi only distributes cpu-only wheels for arm64; the cuda wheels live on a separate index
solution: configure uv to use pytorch's cuda index with a pytorch-triton dependency override
key config: add [[tool.uv.index]] for pytorch cuda wheels + override pytorch-triton for arm64 compatibility

first official support: pytorch 2.7 (2025) - arm64 cuda wheels are still relatively new

the problem

what goes wrong

when you install pytorch on an arm64/aarch64 system using standard methods:

pip install torch
# or
uv pip install torch

you get a cpu-only version, even though:

  • your system has an nvidia gh200 gpu (96gb or 144gb hbm3/hbm3e)
  • cuda 12.8 is installed (nvcc --version works)
  • nvidia drivers are working (nvidia-smi shows the gpu)

how to detect the issue

# check cuda availability
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
# ❌ CUDA available: False

# check pytorch version
python -c "import torch; print(torch.__version__)"
# ❌ 2.9.0+cpu  (notice the "+cpu" suffix)

the correct version should show 2.9.0+cu128 or 2.10.0.dev20251109+cu128.
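these checks can also be run as one small script, for example as a sanity check in ci (a minimal sketch):

import torch

version = torch.__version__
cuda_ok = torch.cuda.is_available()
print(f"torch {version}, cuda available: {cuda_ok}")
# a "+cpu" local version suffix or a missing cuda runtime both mean the wrong wheel
if version.endswith("+cpu") or not cuda_ok:
    raise SystemExit("cpu-only pytorch detected - see the solution section below")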

why this happens

pytorch distributes different wheel packages for different platforms:

  1. pypi (default): ships only cpu-only wheels for arm64
  2. pytorch's cuda-specific indices: host the gpu-enabled wheels for arm64

reasons for this split:

  • arm64 gpu support is recent (pytorch 2.7+, 2025)
  • cuda wheels are large (500mb+ with bundled cuda libraries)
  • arm64 cuda users are still a small minority

the solution

quick install with uv

add this configuration to your pyproject.toml:

[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true

[tool.uv]
override-dependencies = [
    "pytorch-triton ; sys_platform == 'win32'",  # make triton optional on arm64
]

[tool.uv.sources]
torch = { index = "pytorch-nightly-cu128" }

then sync:

uv sync --reinstall-package torch --prerelease=allow

stable pytorch 2.9

for production use, switch to stable 2.9:

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv]
override-dependencies = [
    "pytorch-triton ; sys_platform == 'win32'",
]

[tool.uv.sources]
torch = { index = "pytorch-cu128" }

install:

uv sync --reinstall-package torch

verification

test your installation:

# test 1: check cuda availability
uv run python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
# ✅ CUDA available: True

# test 2: check version and device
uv run python -c "
import torch
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA version: {torch.version.cuda}')
print(f'Device count: {torch.cuda.device_count()}')
print(f'Device 0: {torch.cuda.get_device_name(0)}')
"
# ✅ PyTorch version: 2.10.0.dev20251109+cu128
# ✅ CUDA version: 12.8
# ✅ Device count: 1
# ✅ Device 0: NVIDIA GH200 480GB

# test 3: run gpu computation
uv run python -c "
import torch
x = torch.randn(1000, 1000).cuda()
y = x @ x.T
print(f'✓ GPU tensor operation successful')
print(f'Result shape: {y.shape}, device: {y.device}')
"
# ✅ Result shape: torch.Size([1000, 1000]), device: cuda:0

understanding the configuration

custom index definition

[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = true

components:

  • name: arbitrary identifier for the index
  • url: pytorch’s nightly cuda 12.8 wheel repository
  • explicit = true: only use this index for explicitly configured packages (not transitive dependencies like numpy)

available indices:

  • nightly cu128: https://download.pytorch.org/whl/nightly/cu128 - latest features (2.10.0.dev)
  • stable 2.9: https://download.pytorch.org/whl/cu128 - pytorch 2.9.0+cu128
  • stable 2.8: https://download.pytorch.org/whl/cu126 - pytorch 2.8.0+cu126 (cuda 12.6)

triton dependency override

[tool.uv]
override-dependencies = [
    "pytorch-triton ; sys_platform == 'win32'",
]

what this does:

tells uv to only require pytorch-triton on windows, effectively making it optional on arm64 linux.

why needed:

  • triton only provides x86_64 wheels (no arm64 wheels available)
  • pytorch depends on triton for torch.compile() optimizations
  • update (2025): triton 3.5.0 now has arm64 wheels on pypi, but integration with pytorch cuda on gh200 may still require this override

impact:

  • ✅ pytorch cuda operations work fine
  • ✅ training and inference work
  • torch.compile() may not be available (depending on triton version)
  • ❌ some advanced optimizations disabled

for most training/inference workloads, this is acceptable.
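if your code calls torch.compile(), you can guard it so the same script runs whether or not a working triton is installed (a minimal sketch, not part of the configuration above):

import torch

def maybe_compile(model):
    # torch.compile needs a working triton backend; on arm64 setups where
    # pytorch-triton was overridden away, fall back to eager execution.
    try:
        import triton  # noqa: F401
        return torch.compile(model)
    except ImportError:
        print("triton not available - running in eager mode")
        return model

model = maybe_compile(torch.nn.Linear(8, 8))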

package source mapping

[tool.uv.sources]
torch = { index = "pytorch-nightly-cu128" }

directs uv to fetch torch from the custom cuda index instead of pypi.

nvidia gh200 architecture

overview

the nvidia gh200 grace hopper superchip combines:

  • grace cpu: up to 72 arm neoverse v2 cores (armv9.0-a isa)
  • hopper gpu: h100-class gpu with up to 144gb hbm3e memory
  • nvlink-c2c: 900 gb/s coherent interconnect between cpu and gpu

specifications

cpu (grace):

  • 72 arm neoverse v2 cores
  • 4×128-bit simd units per core
  • arm scalable vector extensions 2 (sve2)
  • up to 480-512 gb lpddr5x memory (546 gb/s bandwidth)
  • 117 mb l3 cache

gpu (hopper h100/h200):

  • 96gb hbm3 (h100) or 144gb hbm3e (h200)
  • 4th generation tensor cores
  • 3-4tb/s memory bandwidth
  • fp8 precision support for ai workloads

memory architecture:

  • cpu: up to 480gb lpddr5x
  • gpu: up to 144gb hbm3e
  • combined: up to 624gb unified memory space
  • nvl2 config (dual superchip): 288gb hbm + 1.2tb total memory
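to confirm which variant you are running on, you can query the gpu from pytorch once the cuda wheel is installed (a quick sketch):

import torch

props = torch.cuda.get_device_properties(0)
print(f"gpu: {props.name}")
print(f"hbm: {props.total_memory / 1e9:.0f} GB")
print(f"compute capability: {torch.cuda.get_device_capability(0)}")  # hopper reports (9, 0)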

why arm64 matters

unlike traditional x86_64 servers with nvidia gpus, gh200 uses arm64 architecture:

  • different instruction set (aarch64 vs x86_64)
  • different binary compatibility requirements
  • different wheel packages needed from pytorch
  • historically limited software ecosystem support

arm64 advantages for ai:

  • better power efficiency
  • unified memory architecture with gpu via nvlink-c2c
  • lower latency cpu-gpu communication
  • native support for modern arm features (sve2)

pytorch arm64 cuda support history

timeline

before pytorch 2.7: no official support

  • pytorch 1.x - 2.6: arm64 wheels were cpu-only
  • workarounds required:
    • build pytorch from source with cuda
    • use nvidia ngc containers
    • use community wheels (e.g., kumatee/pytorch-aarch64)

pytorch 2.7 (2025): breakthrough release 🎉

  • first official arm64 cuda wheels released
  • cuda 12.8 support for both x86_64 and aarch64
  • hosted on pytorch’s cuda wheel indices
  • support for nvidia blackwell gpu architecture

pytorch 2.8-2.9 (2025): maturation

  • continued arm64 cuda support
  • improved wheel distribution system
  • better arm64 optimizations

pytorch 2.10 (current nightly)

  • ongoing arm64 cuda improvements
  • latest optimizations for grace hopper systems
  • available in nightly builds

key insight

arm64 cuda support in pytorch is very recent (less than a year old as of late 2025). this explains:

  • why it’s not well-documented
  • why default installations don’t work
  • why special configuration is required
  • why many developers don’t know about this issue

installation guide

prerequisites

  1. nvidia gh200 system with:

    • cuda 12.8 installed (nvcc --version)
    • nvidia drivers working (nvidia-smi)
    • python 3.10+ (python 3.13 recommended)
  2. uv package manager installed:

    curl -LsSf https://astral.sh/uv/install.sh | sh
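a quick pre-flight check of these prerequisites from python (a sketch - adjust the version bounds to your environment):

import platform
import shutil
import sys

# gh200 systems report aarch64; anything else will not get arm64 cuda wheels
assert platform.machine() == "aarch64", "expected an arm64/aarch64 machine"
assert (3, 10) <= sys.version_info[:2] <= (3, 13), "python 3.10-3.13 required"
for tool in ("nvcc", "nvidia-smi", "uv"):
    status = "found" if shutil.which(tool) else "MISSING"
    print(f"{tool}: {status}")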

step-by-step installation

step 1: create/update pyproject.toml

add the configuration from the solution section above.

step 2: sync dependencies

for existing projects with torch already installed:

uv sync --reinstall-package torch --prerelease=allow

for new projects:

uv sync

the --prerelease=allow flag is needed for nightly builds.

step 3: verify installation

the sync in step 2 downloads ~3gb of cuda libraries on first run; once it completes, re-run the checks from the verification section above. the largest downloads:

  • nvidia-cublas-cu12 (563 mb)
  • nvidia-cudnn-cu12 (672 mb)
  • nvidia-nccl-cu12 (308 mb)
  • nvidia-cusparse-cu12 (275 mb)
  • nvidia-cusolver-cu12 (255 mb)
  • nvidia-cufft-cu12 (184 mb)
  • nvidia-nvshmem-cu12 (133 mb)
  • torch itself (535 mb)

expected output:

Installed 17 packages in Xms
 + nvidia-cublas-cu12==12.8.4.1
 + nvidia-cuda-cupti-cu12==12.8.90
 + nvidia-cuda-nvrtc-cu12==12.8.93
 + nvidia-cuda-runtime-cu12==12.8.90
 + nvidia-cudnn-cu12==9.10.2.21
 + nvidia-cufft-cu12==11.3.3.83
 + nvidia-cufile-cu12==1.13.1.3
 + nvidia-curand-cu12==10.3.9.90
 + nvidia-cusolver-cu12==11.7.3.90
 + nvidia-cusparse-cu12==12.5.8.93
 + nvidia-cusparselt-cu12==0.7.1
 + nvidia-nccl-cu12==2.27.5
 + nvidia-nvjitlink-cu12==12.8.93
 + nvidia-nvshmem-cu12==3.4.5
 + nvidia-nvtx-cu12==12.8.90
 + torch==2.10.0.dev20251109+cu128

troubleshooting

still getting cpu-only pytorch

symptom:

$ python -c "import torch; print(torch.__version__)"
2.9.0+cpu  # still has +cpu suffix

solutions:

  1. verify index configuration in pyproject.toml
  2. force reinstall:
    uv pip uninstall torch
    uv sync --reinstall-package torch
  3. check lock file: delete uv.lock and re-sync:
    rm uv.lock
    uv sync

dependency resolution failures

symptom:

error: Distribution `torch==...` can't be installed because it doesn't
have a source distribution or wheel for the current platform

solutions:

  1. add prerelease flag if using nightly:

    uv sync --prerelease=allow
  2. check python version: must be 3.10-3.13

    python --version
  3. verify architecture:

    uname -m  # should show: aarch64

pytorch-triton errors

symptom:

Because there is no version of pytorch-triton==3.5.0+git...
and torch==... depends on pytorch-triton...

solution:

ensure the override is correct in pyproject.toml:

[tool.uv]
override-dependencies = [
    "pytorch-triton ; sys_platform == 'win32'",  # note: pytorch-triton, not triton
]

import errors after installation

symptom:

ImportError: libcuda.so.1: cannot open shared object file

solutions:

  1. verify cuda is in library path:

    export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
  2. check cuda installation:

    nvcc --version
    ldconfig -p | grep cuda

out of memory errors

symptom:

RuntimeError: CUDA out of memory. Tried to allocate X MiB

solutions:

  1. check available memory:

    nvidia-smi --query-gpu=memory.free,memory.total --format=csv
  2. reduce batch size in training configuration

  3. enable gradient checkpointing to trade compute for memory

  4. use mixed precision (bf16/fp16) to reduce memory usage
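a minimal sketch of checks 1 and 4 from python (the model and batch size here are placeholders; adapt them to your training loop):

import torch

free, total = torch.cuda.mem_get_info()
print(f"free gpu memory: {free / 1e9:.1f} / {total / 1e9:.1f} GB")

model = torch.nn.Linear(4096, 4096).cuda()
x = torch.randn(32, 4096, device="cuda")

# bf16 autocast roughly halves activation memory; hopper supports bf16 natively
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).square().mean()
loss.backward()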

performance considerations

expected performance

on gh200 with properly configured cuda pytorch:

  • 2-3x faster training vs cpu-only
  • full utilization of 96gb/144gb hbm3e gpu memory
  • efficient cpu-gpu communication via nvlink-c2c (900 gb/s)
  • native fp8 support for transformer models
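actual training speedups depend heavily on the model and data pipeline; a matmul microbenchmark is a quick way to get a rough number on your own system (a sketch):

import time

import torch

def bench(device, n=4096, iters=10):
    x = torch.randn(n, n, device=device)
    x @ x  # warm-up (also triggers cuda context creation on the gpu)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        x @ x
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters

print(f"cpu:  {bench('cpu'):.4f} s per 4096x4096 matmul")
print(f"cuda: {bench('cuda'):.4f} s per 4096x4096 matmul")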

nightly vs stable

nightly builds (recommended for gh200):

  • ✅ latest arm64 optimizations
  • ✅ most recent cuda features
  • ✅ active development for grace hopper
  • ❌ may have occasional bugs
  • ❌ less tested than stable releases

stable builds:

  • ✅ more thoroughly tested
  • ✅ better for production deployments
  • ❌ older arm64 optimizations
  • ❌ may lack newest features

updating pytorch

to update to newer nightly builds:

uv sync --upgrade-package torch

to lock to a specific version, modify pyproject.toml:

dependencies = [
    "torch==2.10.0dev20251109+cu128",  # specific nightly version
]

alternative approaches

nvidia ngc containers

for users who prefer containers over direct installation:

docker run --gpus all -it nvcr.io/nvidia/pytorch:24.02-py3

nvidia provides pre-configured pytorch containers with arm64 support. this is the officially recommended approach for gh200 systems and may be more stable than manual wheel installation.

advantages:

  • pre-configured environment
  • tested by nvidia
  • includes optimized libraries
  • no dependency resolution issues

disadvantages:

  • larger download size
  • less flexible than native installation
  • container overhead

building from source

if you need absolute latest features or custom optimizations:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
pip install -r requirements.txt  # build-time python dependencies
export USE_CUDA=1
python setup.py install

note: building from source is time-consuming (1-2 hours on gh200) and requires careful dependency management. only recommended for advanced users or when prebuilt wheels are insufficient.

key takeaways

  • problem: default pip install torch on arm64 gives a cpu-only version, even with an nvidia gh200 gpu
  • cause: pypi only distributes cpu-only wheels for arm64; cuda wheels live on a separate pytorch index
  • solution: configure uv with a custom pytorch cuda index + a pytorch-triton dependency override
  • timeline: official arm64 cuda support started with pytorch 2.7 (2025) - still relatively new
  • triton issue: pytorch-triton historically lacked arm64 wheels, requiring a dependency override workaround
  • gh200 specs: 72-core arm cpu + h100/h200-class gpu with 900 gb/s nvlink-c2c interconnect
  • verification: check that the version shows the +cu128 suffix and torch.cuda.is_available() returns True
  • download size: the first install pulls ~3gb of cuda libraries
  • nightly vs stable: nightly (2.10.0.dev) has the latest arm64 optimizations; stable (2.9.0) is more tested
  • alternative: nvidia ngc containers provide a pre-configured pytorch environment for gh200
  • performance: expect 2-3x speedup vs cpu-only with full utilization of the 96-144gb gpu memory
