AlphaTensor
Goal: Discover faster (matrix multiplication) algorithms with RL.
Contribution: Discover SOTA algorithms for matrix multiplication and several related math problems.
Concept
Figure: Overview of AlphaTensor, from the DeepMind blog. A single-player game played by AlphaTensor, where the goal is to find a correct matrix multiplication algorithm. The state of the game is a cubic array of numbers (shown as grey for 0, blue for 1, and green for -1), representing the remaining work to be done.
- Formulate the matrix multiplication problem as a tensor decomposition problem. This is discussed in previous works such as (Lim 2021, p. 67, Example 3.9).
- Represent a matrix multiplication algorithm as a tensor. An $n \times n$ matrix multiplication algorithm (independent of the matrices being multiplied) can be represented by a fixed 3D ($n^2 \times n^2 \times n^2$) tensor $\mathcal{T}_n$ (with entries in $\{0, 1\}$) called the matrix multiplication tensor (or Strassen's tensor). $\mathcal{T}_n$ can be constructed trivially from the naive matrix multiplication algorithm, with $(\mathcal{T}_n)_{ijk} = 1$ iff the product $a_i b_j$ appears in the sum of terms defining $c_k$ (a construction sketch follows).
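A minimal NumPy sketch of this construction (the helper name `matmul_tensor` and the flattening convention are mine, not from the paper; the paper flattens $C$ in transposed order, which differs only by a permutation of one axis):

```python
import numpy as np

def matmul_tensor(n: int) -> np.ndarray:
    # T has shape (n^2, n^2, n^2). T[i*n+j, j*n+k, i*n+k] = 1 encodes that
    # the scalar product a_{ij} * b_{jk} contributes to c_{ik} in the naive
    # algorithm; all other entries are 0.
    T = np.zeros((n * n, n * n, n * n), dtype=np.int64)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                T[i * n + j, j * n + k, i * n + k] = 1
    return T
```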
- Execute the matrix multiplication algorithm based on a decomposition of the tensor: $\mathcal{T}_n = \sum_{r=1}^{R} u^{(r)} \otimes v^{(r)} \otimes w^{(r)}$, where $u^{(r)}, v^{(r)}, w^{(r)}$ are vectors and $\otimes$ denotes the outer (tensor) product. The outer product of each triplet is a rank-one tensor, and the sum of these rank-one tensors reconstructs $\mathcal{T}_n$ with rank at most $R$. Solving the tensor decomposition problem therefore yields an algorithm for multiplying arbitrary $n \times n$ matrices with $R$ scalar multiplications: compute $m_r = \big(\sum_i u_i^{(r)} a_i\big)\big(\sum_j v_j^{(r)} b_j\big)$ for $r = 1, \dots, R$, then $c_k = \sum_{r=1}^{R} w_k^{(r)} m_r$ (Algorithm 1 of Fawzi et al., 2022). A runnable sketch follows.
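A sketch of executing Algorithm 1 given factor matrices, assuming the flattening convention of `matmul_tensor` above (`multiply_with_factors` is an illustrative name, not from the official code):

```python
def multiply_with_factors(A, B, U, V, W):
    # U, V, W have shape (n^2, R); column r holds u^(r), v^(r), w^(r).
    a, b = A.flatten(), B.flatten()
    # The R scalar multiplications: m_r = (u^(r) . a) * (v^(r) . b).
    m = (U.T @ a) * (V.T @ b)
    # Linear recombination only: c_k = sum_r w_k^(r) * m_r.
    c = W @ m
    n = A.shape[0]
    return c.reshape(n, n)
```

Only the element-wise product forming `m` counts toward the multiplication budget; the recombinations use small fixed coefficients and are treated as additions.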
- Example: the tensor $\mathcal{T}_2$ and its decomposition corresponding to Strassen's algorithm ($n = 2$, $R = 7$). Figure: Matrix multiplication tensor and algorithms, from Fig. 1 of Fawzi et al., 2022. Please note that the (column) vectors $u^{(r)}, v^{(r)}, w^{(r)}$ are stacked horizontally into the matrices $U$, $V$, $W$ (a verification sketch follows this item):
  - The elements $u_i^{(r)}$ of $U$ are indicators for the presence of $a_i$ in $m_r$.
  - The elements $v_j^{(r)}$ of $V$ are indicators for the presence of $b_j$ in $m_r$.
  - The elements $w_k^{(r)}$ of $W$ are indicators for the presence of $m_r$ in $c_k$.
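A sketch that checks Strassen's classical rank-7 factors against the helpers above (the matrices below are the standard Strassen identities written in the flattening convention of `matmul_tensor`, not copied from the paper's figure, which may order the products differently):

```python
# Columns r = 1..7 correspond to Strassen's products m_1..m_7.
U = np.array([[1, 0, 1, 0, 1, -1, 0],    # a11
              [0, 0, 0, 0, 1, 0, 1],     # a12
              [0, 1, 0, 0, 0, 1, 0],     # a21
              [1, 1, 0, 1, 0, 0, -1]])   # a22
V = np.array([[1, 1, 0, -1, 0, 1, 0],    # b11
              [0, 0, 1, 0, 0, 1, 0],     # b12
              [0, 0, 0, 1, 0, 0, 1],     # b21
              [1, 0, -1, 0, 1, 0, 1]])   # b22
W = np.array([[1, 0, 0, 1, -1, 0, 1],    # c11
              [0, 0, 1, 0, 1, 0, 0],     # c12
              [0, 1, 0, 1, 0, 0, 0],     # c21
              [1, -1, 1, 0, 0, 1, 0]])   # c22

# The seven rank-one tensors sum to the full tensor T_2 ...
assert (np.einsum('ir,jr,kr->ijk', U, V, W) == matmul_tensor(2)).all()

# ... and the factors multiply arbitrary 2x2 matrices with 7 products.
A, B = np.random.randn(2, 2), np.random.randn(2, 2)
assert np.allclose(multiply_with_factors(A, B, U, V, W), A @ B)
```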
- Formulate the tensor decomposition problem as a game named TensorGame (a minimal environment sketch follows this item).
  - State: the residual tensor $\mathcal{S}_t$, initialized as $\mathcal{S}_0 = \mathcal{T}_n$.
  - Action: a factor triplet $(u^{(t)}, v^{(t)}, w^{(t)})$ in canonical form, with coefficients in a discrete set $F$ (for example $F = \{-2, -1, 0, 1, 2\}$); applying it updates the state to $\mathcal{S}_t = \mathcal{S}_{t-1} - u^{(t)} \otimes v^{(t)} \otimes w^{(t)}$.
  - Terminate Condition: when $\mathcal{S}_t = 0$ or $t = R_{\text{limit}}$ (a preset step limit).
  - Reward: $-1$ for every step taken; if the game ends with $\mathcal{S}_{R_{\text{limit}}} \neq 0$, an additional reward of $-\gamma(\mathcal{S}_{R_{\text{limit}}})$, where $\gamma(\cdot)$ is an upper bound on the rank of the terminal tensor. The reward function may be modified to optimize practical runtime instead of rank.
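A minimal environment sketch under the definitions above, reusing the helpers from earlier. The class shape and the penalty are illustrative: for $\gamma$ I use the sum of matrix ranks of the slices, which is one valid upper bound on tensor rank, while the paper's exact bound may differ.

```python
class TensorGame:
    def __init__(self, n: int, step_limit: int):
        self.S = matmul_tensor(n)              # S_0 = T_n
        self.step_limit = step_limit
        self.t = 0

    def step(self, u, v, w):
        # Action: subtract the rank-one tensor u (x) v (x) w from the state.
        self.S = self.S - np.einsum('i,j,k->ijk', u, v, w)
        self.t += 1
        solved = not self.S.any()
        done = solved or self.t == self.step_limit
        reward = -1.0                          # each step costs one multiplication
        if done and not solved:
            # gamma(S): the sum of slice ranks upper-bounds the tensor rank.
            reward -= sum(np.linalg.matrix_rank(self.S[i])
                          for i in range(self.S.shape[0]))
        return self.S, reward, done
```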
- Solve the TensorGame with DRL and MCTS.
  - Network Architecture:
    - Torso: modified transformer architecture.
    - Policy head: transformer architecture (decodes the factor triplet autoregressively).
    - Value head: 4-layer MLP that outputs estimates of quantiles of the return distribution.
  - Demonstration Data: randomly sample factors $\{(u^{(r)}, v^{(r)}, w^{(r)})\}_{r=1}^{R}$, sum the outer products to construct a tensor $\mathcal{D} = \sum_{r=1}^{R} u^{(r)} \otimes v^{(r)} \otimes w^{(r)}$, and imitate these synthetic demonstrations, whose decompositions are known by construction (a generation sketch follows below).
  - Objective: quantile regression for the value network, and KL-divergence minimization for the policy network (against the synthetic demonstrations or the sample-based improved policy from MCTS).
  - Input Augmentation: apply a random change of basis to $\mathcal{T}_n$ before each game (i.e., define the matrix multiplication tensor in another basis), and apply a random (signed permutation) change of basis in each new MCTS node (a sketch follows below).

Figure: Overview of AlphaTensor, from Fig. 2 of Fawzi et al., 2022.
Please refer to the paper for figures of the architecture.
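A minimal sketch of the synthetic demonstration generation described above (the sampling distribution is simplified to uniform over $F$; the paper's actual factor distribution is more elaborate):

```python
def synthetic_demonstration(dim: int, R: int, rng: np.random.Generator):
    # Sample R factor triplets with entries in F = {-2, -1, 0, 1, 2} and
    # sum their outer products. The resulting tensor D comes with a known
    # rank-<=R decomposition, giving a supervised target for imitation.
    F = np.array([-2, -1, 0, 1, 2])
    U = rng.choice(F, size=(dim, R))
    V = rng.choice(F, size=(dim, R))
    W = rng.choice(F, size=(dim, R))
    D = np.einsum('ir,jr,kr->ijk', U, V, W)
    return D, (U, V, W)
```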
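A sketch of the change-of-basis augmentation (function names are mine; the paper transforms all three modes by invertible matrices, with signed permutations as the cheap special case applied inside MCTS):

```python
def change_of_basis(T, P, Q, S):
    # Transform each mode of T by an invertible matrix. A decomposition of
    # the transformed tensor maps back to a decomposition of T, so the agent
    # trains on many equivalent views of the same problem.
    return np.einsum('ijk,ai,bj,ck->abc', T, P, Q, S)

def random_signed_permutation(d: int, rng: np.random.Generator):
    # A signed permutation matrix: exactly one +/-1 entry per row and column.
    M = np.zeros((d, d), dtype=np.int64)
    M[np.arange(d), rng.permutation(d)] = rng.choice([-1, 1], size=d)
    return M
```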
Results and Costs
- Results
  - SOTA matrix multiplication algorithms for certain matrix sizes (e.g., 4x4 matrices over $\mathbb{Z}_2$ with 47 multiplications, improving on Strassen's 49).
  - SOTA algorithm for multiplying a skew-symmetric matrix with a vector.
  - Re-discovers the Fourier basis.
- Training Costs
  - Trained on TPU v3 and TPU v4 hardware; takes about a week to converge.
Official Resources
- [Nature 2022] Discovering faster matrix multiplication algorithms with reinforcement learning [paper][blog][code] (citations: 98, 78, as of 2023-04-28)