# Google TPU 101

对于习惯了 CUDA 的工程师来说，对 GPU 的认知往往建立在 SM、Warp 调度、寄存器和 Shared Memory 之上。开发者通过 Kernel、Grid/Block 配置和 shared memory tile 化，挖掘硬件算力，关注的核心是「每个线程在干什么」、「访存是否 coalesced」以及「Warp 的分支发散情况」。

TPU 要解决的问题和 GPU 类似：高吞吐地跑深度学习推理和训练。但 Google 并没有「再造一个 GPU」，而是从头围绕深度学习工作负载做了一个几乎完全不同的架构 [1,2]。这套架构，从硬件到软件，都围绕两个核心原则展开：深度学习的主力工作就是大规模矩阵乘加（MatMul/GEMM）加上少量逐元素运算，以及模型规模持续变大，多芯片协同是常态而不是例外。

围绕这两个原则，TPU 在硬件和软件层面做了两组非常鲜明的取舍。在硬件层面，TPU 放弃了复杂的 SIMT + Warp 调度机制、传统通用 Cache 层级和精细可编程的 shared memory，以及面向图形和通用计算的指令级灵活性。取而代之，TPU 极致强化了固定规模的 systolic array 矩阵乘单元（MXU）、高带宽弱 Cache 语义的 HBM 加本地 scratchpad，以及芯片间高带宽面向 collective 的专用互联。

在软件层面，TPU 同样放弃了以 Kernel 为中心、由工程师手工指定 launch 配置的编程模型，以及显式管理 shared memory、warp 和指令序列的「手工调度权」。相应地，TPU 极致强化了以 **XLA** 为核心的图级编译器 [3]，聚焦 HLO / StableHLO 层的跨框架可移植性 [3,5]，以及针对特定硬件（TPU Pod）的全局图优化和集体通信调度。

所以，本文的两条主线是：

- 从硬件和软件两个维度，回答：**TPU 在架构上「放弃了什么」，又「极致强化了什么」？**
- 从工程实践的角度，回答：**如果你已经会写 CUDA MatMul，那么在 TPU 上 MatMul 是如何「被写出来」的？**

阅读时可以带着一个简单的对照问题：在 GPU 上，你优化的是「线程如何跑得更快」；在 TPU 上，你则要学会优化「矩阵如何更顺畅地流过整个系统」。

---

## 1. TPU 的设计出发点：为什么 Google 没有“再造一个 GPU”

### 1.1 从 GPU 的“通用性负担”说起

GPU 的成功很大程度上来自「通用性」：既能跑图形，又能跑科学计算、视频编解码和深度学习。为了支撑这种通用性，GPU 架构天然背负了不少历史包袱。

首先是图形流水线遗产，早期 GPU 为图形渲染而生，Shader 模型、纹理单元、固定功能管线等设计长期影响着硬件形态。其次是 SIMT + Warp 调度复杂度，单指令多线程（SIMT）模型需要硬件负责 Warp 形成、切换和分支掩码管理，这带来了复杂的控制逻辑。最后是 Cache、Shared Memory 和 Register 的多层次管理成本，为了兼顾不同访问模式，GPU 掌握着多级 Cache、软件管理的 shared memory 和大规模寄存器文件，程序员需要在这些层级之间不断做权衡。

但对现代深度学习工作负载而言，需求抽象其实要简单得多 [2]。主力是大规模 dense 或 semi-dense 的矩阵乘加（MatMul、Conv），数据流大多是可预测的，很少出现完全随机、无规律的访问模式，控制流复杂度相对较低，长而规整的算子链条非常常见。换句话说，GPU 为了通用性保留了大量「可能会用到」的能力，但深度学习主力工作负载并不总是需要这些能力，甚至在某些场景下，这些通用机制本身会成为优化的负担。

### 1.2 TPU 的根本假设

基于在 Google 内部大规模深度学习工作负载的经验，TPU 团队做了一个激进但清晰的假设 [1,2]：

- **训练 / 推理 ≈ 大规模矩阵乘加 + 少量逐元素运算。**

在这个假设下，TPU 在设计上与 GPU 产生了根本分歧。GPU 主要优化「如何让线程跑得更快」，重点在于单个线程或 Warp 的效率、指令调度和访存 coalescing。而 TPU 则优化「如何让矩阵一直在流动」，重点在于 systolic array 的填满率、数据进入和流出阵列的效率，以及在整个 Pod 级别让矩阵不断流动。

这直接驱动了后续硬件和软件的取舍：

- 在硬件上，TPU 可以牺牲灵活性，换取为矩阵乘专门打造的 systolic array 和数据流路径。
- 在软件上，TPU 可以把「调度权」交给编译器，让工程师重点描述计算图本身，而不是 Kernel 级别的执行细节。

---

## 2. 硬件架构对比：SM + Tensor Core vs MXU + Systolic Array

### 2.1 GPU 架构快速回顾（作为对照系）

先快速回顾一下 GPU 结构，作为后文的对照系。

在一个典型的 NVIDIA GPU 上：

- SM 的核心职责包括：

  - 负责 Warp 形成和调度；
  - 负责指令的发射与流水线控制；
  - 协调 Tensor Core 与标量 / 向量单元共同工作。

- Tensor Core 的工作方式可以概括为：
  - Tile-based MMA：以固定大小的矩阵 tile 为基本操作单元；
  - 通过 WMMA、CUTLASS 或编译器自动生成的指令来触发；
  - Kernel 作者需要做大量 tile 划分和 shared memory 管理，才能把 Tensor Core 利用率拉起来。

从这个视角看，GPU 是一个典型的 **线程中心（thread-centric）** 架构：一切优化几乎都可以还原为「某个线程 / Warp 在什么时间访问了什么地址、发出了什么指令」。

---

### 2.2 TPU 核心计算单元：MXU（Matrix Multiply Unit）

在 TPU 上，核心计算单元不再被称为 Tensor Core，而是 Matrix Multiply Unit（MXU）。MXU 并不是简单放大版的 Tensor Core，而是围绕 systolic array 专门设计的矩阵乘机器 [1,2]。

MXU 的关键特征包括：

- 固定规模的 **systolic array**（典型配置为 128 × 128 ALU 阵列）[1,2]：
  - 每个阵列元素负责简单的乘加（MAC）操作；
  - 数据沿着阵列的行、列「心跳式」推进。
- 硬件级数据推进（dataflow in hardware）：
  - 阵列内的数据流由硬件自动推动，而不是依赖每个线程显式 load/store；
  - 一旦数据被送入阵列，之后的传播和累加几乎不需要指令介入。
- 几乎没有指令级调度概念：
  - MXU 更像一个「一次性启动的大型流水线」，而不是频繁发射的指令流；
  - 对编程模型来说，MXU 是不可见的，你看不到「某条矩阵乘指令」。
- 一个 TPU TensorCore 由 MXU + 向量单元 + 标量单元组成 [1]：
  - MXU 负责大块矩阵乘；
  - 向量单元负责逐元素运算、简单非线性；
  - 标量单元负责控制流、地址生成和数据搬运。

这组设计的直接结果是：TPU 在硬件上 **放弃了线程级别的精细控制权**，换取了一个为大规模矩阵乘专门铺好的流水线。

从对比视角来看，GPU Tensor Core 和 TPU MXU 在多个维度上存在显著差异。在调度单位方面，GPU 采用 Warp 或 Instruction 级别的调度，而 TPU 则采用数据流调度方式。编程可见性方面，GPU 的 Tensor Core 对程序员是可感知的，而 TPU 的 MXU 对编程模型来说是完全不可见的。在 Tile 控制方面，GPU 需要程序员或 XLA 进行控制，而 TPU 则由纯硬件控制。最终目标方面，GPU 主要关注提升单次 MMA 性能，而 TPU 则致力于最大化持续吞吐。

---

### 2.3 Systolic Array：TPU 的“第一性硬件结构”

为什么 TPU 会选择 systolic array，而不是继续沿用 GPU 的 SIMT + Tensor Core 模型？

从硬件的「第一性原理」来看，矩阵乘的核心在于重复的 MAC 操作和可预测的数据访问模式。systolic array 正好满足这两个特点：

- 每个阵列单元只做非常简单的乘加；
- 数据可以沿着规则的方向以固定节奏推进；
- 同一个数据在阵列中被多次复用，而不需要反复访问外部存储。

如果用数据流来对比 GPU 和 TPU：

- 在 GPU 上，一般的模式是：

  - Load → Compute → Store（循环往复）；
  - 每一个 tile 的输入、输出都需要显式 load/store；
  - 访存模式、shared memory 复用都需要工程师手工设计。

- 在 TPU 上，systolic array 的典型模式更接近：

  - **Load once → Flow through → Accumulate**；
  - 数据在阵列中不断向前流动，沿途经过多个 MAC 单元并被累加；
  - 一旦进入阵列，数据在阵列内部的传播几乎不再消耗外部带宽。

从「放弃与强化」的角度看：

- TPU 放弃了针对任意访问模式的通用负载能力；
- TPU 强化了针对规则数据流的极致吞吐和能效。

---

### 2.4 内存与数据通路设计

TPU 的内存层次和 GPU 看起来类似：都有 HBM，也有本地更快的存储。但设计哲学完全不同 [1]。

TPU 的内存哲学可以概括为采用高带宽 HBM 作为权重和激活的大容量存储，每个 TensorCore 拥有本地向量内存（VMEM / scratchpad）并由编译器显式管理 [1]，采用极弱的 Cache 语义尽量避免「硬件猜测」，以及在编译期确定数据布局与数据移动路径，尽可能在编译阶段决定「哪些数据何时放在哪」。

对比之下，GPU 的做法是把片上 SRAM 分为 shared memory 和各种 Cache，Cache miss 是常态需要依赖硬件猜测访问模式，工程师需要在 shared memory、寄存器和全局内存之间做精细调度。

在这里，TPU 放弃的是「硬件自动猜测访问模式」的便利，而强化的是「由编译器全局规划数据流」的可预测性。

### 2.5 SparseCore：从 v4 起的稀疏加速单元

从 TPU v4 开始，Google 在 TPU 芯片中加入了一类新的硬件单元：SparseCore [1]。它不是 MXU 的替代，而是专门为大规模推荐和 embedding 模型设计的稀疏加速器。

SparseCore 的定位是面向大规模推荐和 embedding 模型的稀疏操作加速 [1]，与 MXU 形成互补关系：MXU 负责 dense 和 semi-dense 计算，而 SparseCore 负责高度稀疏访问。

在架构设计上，SparseCore 采用 tile-based dataflow 处理器架构，每个 tile 拥有本地 SPMEM 与处理单元 [1]。数据布局与缓存策略由 XLA 或上层软件显式控制，而非依赖硬件自动缓存。通过分析 embedding 访问模式，系统能够决定哪些向量缓存于本地 SPMEM，哪些驻留在 HBM 中 [1]。

可以看到，哪怕是在面向稀疏访问模式的硬件设计上，TPU 也依然坚持「由软件（编译器）显式决定数据位置」的设计思想。

---

## 3. TPU 芯片 ≠ TPU 系统：从单卡到 TPU Pod

### 3.1 TPU 的 scale-first 设计

TPU 在设计之初就假设模型会持续变大，多芯片协同训练是常态而不是例外。因此，TPU 的设计采用了 **scale-first** 的理念。在单芯片层面，TPU 追求在给定功耗和面积约束下最大化矩阵吞吐；而在多芯片层面，则通过专用互联（ICI）和 Pod 拓扑，把多个 TPU 打造成一个「系统级加速器」。

与之对应，在 GPU 世界里，单卡性能和灵活性通常被放在首位，多卡或多节点扩展更多依赖软件栈（如 NCCL、Horovod 等）来弥补。在互联技术方面，TPU 使用的是 ICI（Inter-Chip Interconnect），而 GPU 集群中更常见的是 NVLink 或 InfiniBand。虽然两者都提供高带宽低延迟通信，但 TPU 的 ICI 从设计之初就和 collective 通信紧密绑定 [1,2]。

### 3.2 TPU Pod 的系统级抽象

在系统层面，一个 TPU Pod 是按照 2D/3D torus 拓扑组织起来的 [1]。每个 TPU 主板包含多个芯片，主板之间通过 ICI 连接，全 Pod 的 ICI 形成一个规则的二维或三维环形拓扑。在这种设计下，collective 操作（例如 AllReduce、AllGather）在硬件层就是一等公民。

与 GPU 集群相比，这种设计理念存在显著差异。GPU 集群通常表现为节点级能力强，但系统级更多依赖软件栈拼凑起来；而 TPU 则将系统级作为一等公民，Pod 自身就是一个面向大规模训练的整体系统。

这背后对应的设计取舍是：GPU 强化的是单卡通用算力和弹性软件生态，而 TPU 强化的是系统级矩阵吞吐和深度学习专用互联。

### 3.3 Cloud TPU 部署模型与拓扑细节

在 Cloud TPU 上，Google 提供的是以 TPU VM 为基础的编程模型 [2]。TPU VM 直接将 TPU 绑定到一个虚拟机实例上，提供「类本地设备」体验 [2]，而 Pod 配置在物理上提供多芯片、跨机柜的 2D 或 3D torus 拓扑。

ICI 的特性包括面向 AllReduce 和 collective 通信优化的专用互联 [1]，其延迟、带宽与拓扑设计都假设「多芯片同步训练」是最主要场景。

对工程实践而言，这种设计意味着编程模型天然偏向数据并行和张量切分，而不是「多进程多 GPU 自己拼」[2,4]。性能优化更偏向于如何把计算图高效地分布到整个 Pod 上，而不是单纯追求单卡把 FLOPS 吃满。

这也是 TPU 在系统层面「放弃与强化」的典型体现：放弃对杂乱多样拓扑的兼容性，强化对特定拓扑下大规模同步训练的极致支持。

---

## 4. 软件栈总览：从 CUDA 到 XLA 的范式跃迁

### 4.1 CUDA 软件栈回顾

在 CUDA 世界里，软件栈大致可以拆成三层。CUDA Runtime 和 Driver 负责设备管理、内存分配、Kernel 启动、上下文切换等基础功能。Kernel 编程模型让工程师使用 C++ 或 CUDA C 编写 Kernel，显式指定 Grid 和 Block 配置。手工性能调优路径包括修改 block size、grid 维度、线程布局，手工 tile 化数据并搬到 shared memory，以及使用 WMMA 或 CUTLASS 映射到 Tensor Core。

在这个模型下，开发者拥有极大的控制权，几乎能决定每个线程在每个时钟周期做什么，但这也意味着必须为性能负全责。

### 4.2 TPU 软件栈的核心：XLA

在 TPU 世界里，XLA 是整个软件栈的核心 [3]。它是一个面向深度学习的编译器，可以做 ahead-of-time 或 just-in-time 编译。XLA 的定位是从 TensorFlow、JAX、PyTorch/XLA 等前端接收计算图 [2,3,4]，以 HLO 或 StableHLO 作为中间表示进行图级优化 [3,5]，并针对 TPU、GPU、CPU 等不同后端生成高度优化的执行代码。

XLA 的优化与执行路径包括将前端计算图 lower 为 HLO 或 StableHLO 表示，在图级别进行算子融合（Fusion）、内存布局优化和并行划分，把矩阵计算切分成适配 MXU 的 tile（例如 128 × 128）[2]，以及针对特定后端生成目标代码并缓存编译结果以复用 [3,4]。

在 TPU 上，不存在很多开发者已经习惯的概念，例如「TPU Kernel」、「TPU shared memory」或「TPU launch configuration」。这些都被 XLA 收编了，开发者在 TPU 上编写的是「计算图」而不是「Kernel」。

### 4.3 主流前端栈：TensorFlow / JAX / PyTorch XLA

今天在 Cloud TPU 上常见的前端主要有三类。TensorFlow + XLA 通过 `tf.function(jit_compile=True)` 或编译标志启用 XLA [2]，官方文档建议使用静态 shape、规则 batch 和规整张量布局以匹配 MXU 偏好 [2]。

JAX 本质上以 XLA 为默认编译后端 [3]，通过 `jit`、`pmap`、`pjit` 等原语控制编译和分布式并行。

PyTorch / XLA 通过 Lazy Tensor 机制收集 IR 图，再下发到 XLA 编译 [4]，尽可能保留 PyTorch 的使用体验，同时让模型跑在 Cloud TPU 或 XLA GPU 上 [4]。

### 4.4 HLO / StableHLO 与可移植性

在 XLA 内部，计算图会被表示为 HLO（High Level Operations）[3]。HLO 描述算子、shape、布局以及它们之间的数据依赖，是图级优化的工作对象而不是单条指令。

StableHLO 则是在 HLO 之上的一个稳定规范 [5]，面向多框架和多编译器的可移植 HLO 规格，目标是在 TensorFlow、JAX、PyTorch 等前端与 XLA、IREE 等后端之间建立稳定接口。

对 CUDA 工程师而言，这意味着需要从「写 Kernel + Launch Config」升级到「理解编译后图长什么样」，性能调优更多发生在「图级别变换 / 结构化建模」而不是单 Kernel 拆解。

换句话说，TPU 在软件层面放弃了对每个线程行为的精细控制权，强化了对整个计算图的全局掌控力。

---

## 5. 矩阵乘法对照解析：CUDA vs TPU

> 本节是全文的「认知锚点」，目标是让开发者 **真正理解 TPU 的编程方式差异**。

本节将以典型的矩阵乘法（MatMul）为例，对比「在 CUDA 里控制了什么」与「在 TPU 上放弃了什么」。

---

### 5.1 CUDA 中的 MatMul：开发者在控制什么？

假设要在 CUDA 上实现一个简单的矩阵乘 `C = A x B`，即使不使用 WMMA，也通常会采用类似下面的 Kernel 结构：

```cpp
// 一个典型的 CUDA 矩阵乘 Kernel，展示开发者手工控制的要素
__global__ void matmul_kernel(const float* A, const float* B, float* C,
                              int M, int N, int K) {
  // 计算当前线程负责的输出位置
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  // 使用 shared memory 做 tile 化加载
  __shared__ float Asub[BLOCK_SIZE][BLOCK_SIZE];
  __shared__ float Bsub[BLOCK_SIZE][BLOCK_SIZE];

  float acc = 0.0f;
  for (int t = 0; t < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++t) {
    // 每个线程负责加载部分 tile 数据
    int tiledRow = row;
    int tiledCol = t * BLOCK_SIZE + threadIdx.x;
    if (tiledRow < M && tiledCol < K) {
      Asub[threadIdx.y][threadIdx.x] = A[tiledRow * K + tiledCol];
    } else {
      Asub[threadIdx.y][threadIdx.x] = 0.0f;
    }

    tiledRow = t * BLOCK_SIZE + threadIdx.y;
    tiledCol = col;
    if (tiledRow < K && tiledCol < N) {
      Bsub[threadIdx.y][threadIdx.x] = B[tiledRow * N + tiledCol];
    } else {
      Bsub[threadIdx.y][threadIdx.x] = 0.0f;
    }

    __syncthreads();

    // 在共享内存 tile 上做局部乘加
    for (int k = 0; k < BLOCK_SIZE; ++k) {
      acc += Asub[threadIdx.y][k] * Bsub[k][threadIdx.x];
    }

    __syncthreads();
  }

  // 把结果写回全局内存
  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}
```

在这个 Kernel 里，开发者显式控制了：

- Grid / Block 划分：`blockDim`、`gridDim`，决定每个 block 处理多大 tile；
- Tile 大小选择：`BLOCK_SIZE`；
- Shared memory 使用：哪些数据放 shared memory，如何复用；
- 每个线程负责哪些 load/store，以及怎样避免越界；
- 指令层面是否使用 WMMA / Tensor Core。

本质上，开发者在 **显式描述执行策略**：如何在给定硬件上完成这次矩阵乘。

---

### 5.2 TPU 中的 MatMul：开发者“放弃”了什么？

在 TPU 上，同样的矩阵乘可以写成非常简单的形式，例如在 JAX 中：

```python
# 在 JAX + XLA + TPU 上执行矩阵乘，展示开发者只描述数学运算
import jax
import jax.numpy as jnp

def matmul(a, b):
  # 只描述数学运算本身
  return a @ b

# 使用 jit 触发 XLA 编译
jit_matmul = jax.jit(matmul)

# 在 TPU 设备上创建输入张量并执行
a = jnp.ones((M, K), dtype=jnp.float32)
b = jnp.ones((K, N), dtype=jnp.float32)
c = jit_matmul(a, b)  # 实际在 TPU MXU 上执行
```

从这个调用方式看，开发者似乎「什么都没做」：没有 Kernel，没有 blockDim / gridDim，没有 shared memory 的显式声明，更没有任何 warp 相关的参数。

但在背后，XLA 做了很多在 CUDA 世界里需要手工完成的事情 [2,3,4]：

- Shape 推断：根据 `a` 和 `b` 的 shape 推断矩阵乘的具体维度；
- Tile 拆分：把矩阵乘拆解为适合 MXU 的 tile（典型是 128 × 128）[2]；
- Systolic array 映射：决定每个 tile 如何映射到 systolic array 中；
- 数据流调度：决定数据如何从 HBM / VMEM 流入 MXU，又如何被累加和写回。

本质上，开发者在 TPU 上 **放弃了 Kernel 级的执行权**，只描述「数学本身」，把执行策略全部交给编译器。

这背后还有一个重要细节：与硬件偏好的对齐 [2]。

- MXU 一般以 128 × 128 tile 为基本单位进行计算 [2]；
- XLA 会在必要时自动 padding 到硬件友好的尺寸（例如 8、128 的倍数）[2]；
- 高性能 Cloud TPU 程序通常要求 batch size 或特征维度满足这些对齐约束，避免大量无效 padding [2]。

这意味着，**开发者仍然在为性能做决策**，但决策的对象从「线程 / block / shared memory」变成了「shape / batch size / 数据布局」。

---

### 5.3 思维模型对比总结

| **问题**     | **CUDA**    | **TPU**    |
| ------------ | ----------- | ---------- |
| 谁决定 tile  | 程序员      | XLA        |
| 谁调度计算   | Kernel      | 编译器     |
| 优化粒度     | 指令 / Warp | 计算图     |
| 性能调优方式 | 手工        | 结构化建模 |

### 5.4 一个 Conv + BN + ReLU 的对照 case（将在正文展开）

再来看一个更复杂但更贴近日常训练的算子组合：Conv + BN + ReLU。

在 CUDA 视角下，三个算子往往对应多个 Kernel，期间存在多次 HBM 读写。开发者会尝试做 Kernel fusion，把 BN 和 ReLU 融到 Conv kernel 里，需要显式考虑共享内存复用、访存 coalescing 和寄存器压力。

在 TPU / XLA 视角下，XLA 倾向将 Conv + BN + ReLU 融合为一个或少量 HLO / Fusion 区域 [3,4]。编译器负责在 MXU 和向量单元之间安排数据流路径、临时结果存放位置，工程师更多从「图结构是否易于融合」的角度改写模型，而不是手写融合后的 kernel。

这里再次体现出 TPU 的软件取舍：放弃对每个算子如何落在硬件上的直接控制，强化对整个算子图进行统一调度和融合的能力。

---

## 6. TPU 的优势、限制与适用边界

### 6.1 TPU 的“甜点区”

综合硬件和软件的取舍，TPU 的「甜点区」非常清晰 [1,2]：

- 大模型训练：尤其是 Transformer 类模型，在大规模 Pod 上的数据并行或张量并行；
- 高吞吐推理：对单样本延迟不敏感、对吞吐敏感的服务场景（例如批量离线推理）；
- 高度规则的 dense / semi-dense 计算：矩阵乘、卷积、规整的 MLP 等。

在这些场景中，TPU 的 systolic array、高带宽 ICI 和 XLA 图级优化可以最大化发挥效果。

### 6.2 TPU 的天然短板

与之对应，TPU 也有自己的天然短板 [2,4]：

- 非规则计算：高度稀疏、动态结构的图计算，可能难以高效映射到 MXU；
- 细粒度控制流：大量 `if` 分支、动态循环等，可能导致图形态复杂、编译困难；
- 小 batch / 低算力密度场景：难以填满 MXU 和整个 Pod 的算力，吞吐优势发挥不出来。

### 6.3 在 Cloud TPU 上写程序的实用 Checklist

如果开发者已经习惯在 CUDA 上做调优，可以把以下建议当作迁移到 Cloud TPU 时的对照参考 [2,3,4]。

在**模型与数据选择**方面，应优先选择 dense 或 semi-dense 计算密集型模型（如 Transformer、CNN 等）[2]，避免高度动态 shape，尽量在编译期就固定张量形状 [2,4]，如果必须使用动态 shape，也应尽量限制在少数几个模式。

在**张量尺寸与布局**方面，需确保关键维度（batch 或 feature）是 8 或 128 的倍数 [2]，注意 NHWC 或 NCHW 等布局对 MXU tile 利用率的影响，并在可能的情况下，从模型设计阶段就考虑这些对齐约束。

在**编译与调优**方面，利用 XLA profiling、TPU Profiler、XProf 等工具分析 tile 利用率和 padding 情况 [2,3]。优先通过「重构计算图」和「增加并行度」来提升吞吐，而不是抠单次 step 的 latency。对于热点子图，可以考虑在 HLO 或 StableHLO 层面做结构化改写，而不是尝试手写「TPU Kernel」（因为实际上并不存在）。

---

## 7. 总结：从 CUDA 工程师到 TPU 使用者，需要转变什么

回到本文一开始提出的两个核心问题：

- TPU 在硬件和软件层面「放弃了什么」，又「极致强化了什么」？
- 如果你会写 CUDA MatMul，那么 TPU 的 MatMul 是如何「被写出来」的？

从硬件层面看，TPU 放弃了线程级和 warp 级的精细调度权，通用复杂的 Cache 层级和显式可编程的 shared memory，以及面向广义 GPGPU 的通用负载适配能力。

TPU 强化了以 MXU 和 systolic array 为核心的矩阵乘流水线 [1,2]，由编译器显式管理的数据流和本地 scratchpad，以及 ICI 和 Pod 拓扑下的系统级矩阵吞吐能力 [1,2]。

从软件层面看，TPU 放弃了以 Kernel 为中心、以 launch 配置为关键调参手段的编程模型，以及在每个算子上由工程师决定「跑在哪些线程上、如何分 tile」的自由度。

TPU 强化了以 XLA 为核心的图级编译和跨算子融合能力 [3]，以 HLO 和 StableHLO 为中间层的跨框架可移植性 [3,5]，以及针对特定硬件（TPU Pod）的全局图调度和并行划分。

而对 CUDA 工程师来说，从 GPU 世界走向 TPU 世界，本质上是三次心智转变。

首先是从 **「写 kernel」→「写计算图」** 的转变：在 GPU 上，开发者写的是 Kernel，并为每个 Kernel 调参；在 TPU 上，开发者写的是计算图，让编译器为整体图做调度。

其次是从 **「控制硬件」→「相信编译器」** 的转变：在 GPU 上，开发者几乎能追踪每个线程的行为；在 TPU 上，必须假设编译器可以充分利用 MXU 和 Pod 拓扑，并通过 HLO 和 Profiling 来验证和反馈。

最后是从 **「单卡极致优化」→「系统级吞吐最大化」** 的转变：在 GPU 上，很多优化最终落在「单卡、单 kernel 如何跑满」；在 TPU 上，更应该关注「整个 Pod 上的矩阵流动是否顺畅」，包括 batch size、分布式并行策略和数据布局。

当你真正接受了这三次心智转变，就会发现：在 TPU 上，「写 MatMul」这件事情已经不再需要自己手写 Kernel；你的工作重心，变成了**如何把模型和数据设计成易于编译器在 systolic array 和 Pod 上高效运行的形式**。

---

## 8. 参考文献

[1] Google Cloud. "TPU architecture." Google Cloud Documentation. Accessed: Dec. 14, 2025. [Online]. Available: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm

[2] Google Cloud. "Introduction to Cloud TPU." Google Cloud Documentation. Accessed: Dec. 14, 2025. [Online]. Available: https://cloud.google.com/tpu/docs/intro-to-tpu

[3] OpenXLA. "XLA." OpenXLA Documentation. Accessed: Dec. 14, 2025. [Online]. Available: https://openxla.org/xla

[4] PyTorch. "PyTorch/XLA Overview." PyTorch/XLA Documentation. Accessed: Dec. 14, 2025. [Online]. Available: https://pytorch.org/xla/master/learn/xla-overview.html

[5] OpenXLA. "StableHLO." OpenXLA Documentation. Accessed: Dec. 14, 2025. [Online]. Available: https://openxla.org/stablehlo
