分類
未分類

如何在 Windows 上應用最新版的 Tensorflow 訓練模型?你需要 WSL v2

最近在學習使用 Tensorflow 建立 AI 模型。Google Colab 免費也很好用,離峰時段還能使用 GPU or TPU。但是有時候就會想在自己的機器上利用 GPU 來建模。Tensorflow 官方安裝網站 https://www.tensorflow.org/install/pip 提到 Windows 原生支援只到 v2.10 而已。而撰寫本文時最新版本已經來到 v2.16,難道 Windows 使用者就得用虛擬機器跑 Linux 或是尋找其他雲端資源嗎?

不必。只要你的 Windows 可以跑 WSL v2 就能跑最新版的 Tensorflow!

https://www.tensorflow.org/install/pip 有提到,Windows 10 64-bit 版本 19044 或更新版,或是 Windows 11 都支援 WSL v2。在 WSL v2 中,可以選擇運行 Ubuntu 20.04 LTS 或是 Ubuntu 22.04 LTS,都是能執行 Tensorflow 最新版的環境。

這邊我用我的 Surface Book 3 來示範一次。Surface Book 3 搭載一張 NVIDIA GTX 1660 Ti 獨立顯示卡,有 6GB 的 VRAM。雖然不大但還能拿來訓練模型。

要建置環境,首先要先確認安裝 WSL v2 以及 Ubuntu。這邊我選用的是 22.04 LTS。安裝 WSL v2 的方法請自行尋找。

裝好 WSL v2 以及 Ubuntu 環境後,首先要先確認 NVIDIA 驅動程式是否正常運作。啟動 WSL 後執行

nvidia-smi

應該可以得到類似下圖的畫面

我之前曾經遇過一個狀況:執行 nvidia-smi 後,出現了 segmentation fault 的錯誤訊息,類似下面這樣:

雖然我繼續安裝了 tensorflow,最後依然找不到 GPU。當時我的 laptop 安裝的是 NVIDIA 官網下載的最新版驅動程式。後來解決的方法是:

  1. 先使用 Display driver uninstaller 移除 NVIDIA 驅動程式
  2. 然後讓 Windows Update 安裝驅動程式。
  3. 由於 Windows Update 提供的驅動程式版本只有到 51x,實在是太舊了,我就又上了 NVIDIA 官網,找一個稍微舊一點的版本(也就是 546.29)來用。安裝時我選擇快速安裝(Express installation)。完成後在 WSL 內執行 nvidia-smi 得到的結果就是正常的。

接下來就是按照 https://www.tensorflow.org/install/pip#step-by-step_instructions 的介紹依序安裝:

  1. CUDA ToolKit。安裝 CUDA toolkit 要注意版本對應。可以參考 https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id5 例如說要安裝 CUDA toolkit v12 以上的話,驅動程式版本要大於等於 525.60.13。安裝的方法就是到 https://developer.nvidia.com/cuda-toolkit-archive 找你要安裝的版本,點擊該版本的連結後,選 Linux,你的平台格式(既然是 Windows,應該就是 x86_64),然後 WSL-Ubuntu。最後安裝格式根據自己需求選擇。選擇完畢後,該頁面會提供安裝指引,照著做即可。
  1. 下一步,要安裝 cuDNN SDK。到 https://developer.nvidia.com/cudnn 。這邊就不像 CUDA toolkit 還要選版本,直接點 Download cuDNN library 就行了。在 distrubution 選 Ubuntu 就好,然後 version 選擇自己安裝的 Ubuntu 版本。最下面安裝指引也有提醒如何安裝符合特定 CUDA Toolkit 版本的 cuDNN。
  1. 執行下列指令,確保 pip 是最新版
pip install --upgrrade pip
  1. 執行下列指令,安裝支援 GPU 運算的 tensorflow
pip install tensorflow[and-cuda]
  1. 安裝完成後,執行下列指令檢查輸出結果
python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"

如果一切正常,應該會得到類似下圖的輸出

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *