AlphaTensor

Goal: Discover faster (matrix multiplication) algorithms with RL.

Contribution: Discover SOTA algorithms for solving several math problems.

Concept

Overview of AlphaTensor, from DeepMind Blog.

Single-player game played by AlphaTensor, where the goal is to find a correct matrix multiplication algorithm. The state of the game is a cubic array of numbers (shown as grey for 0, blue for 1, and green for -1), representing the remaining work to be done.

  1. Formulate the matrix multiplication problem as a tensor decomposition problem.
    This is discussed in previous works such as (Lim 2021, p.67, Example 3.9).

    • Represent a matrix multiplication algorithm as a tensor.

      An $n \times n$ matrix multiplication algorithm (independent of the matrices being multiplied) can be represented by a fixed 3D ($n^2 \times n^2 \times n^2$) tensor $\mathcal{T}_n$ (with entries in $\{0, 1\}$) called the matrix multiplication tensor (or Strassen's tensor).

      $\mathcal{T}_n$ can be constructed trivially from the naive $O(n^3)$ matrix multiplication algorithm, with $m_{(i-1)n+j} = a_i b_j$ and $c_{(i-1)n+j}$ being the sum of the $n$ terms prescribed by the $O(n^3)$ algorithm.

    • Execute the matrix multiplication algorithm based on the decomposition of the tensor.

      $\mathcal{T}_n = \sum_{r=1}^{R} u^{(r)} \otimes v^{(r)} \otimes w^{(r)}$, where $u^{(r)}, v^{(r)}, w^{(r)} \in \mathbb{R}^{n^2}$ are vectors, and $\otimes$ denotes the outer (tensor) product. The outer product of each triplet $(u^{(r)}, v^{(r)}, w^{(r)})$ is a rank-one tensor, and the sum of these $R$ rank-one tensors equals $\mathcal{T}_n$, with $\operatorname{Rank}(\mathcal{T}_n) \le R$. Solving the tensor decomposition problem yields an $O(n^{\log_n R})$ algorithm for multiplying arbitrary $n \times n$ matrices with $R$ scalar multiplications:

      from Algorithm 1 of Fawzi et al., 2022.

      Example: the matrix multiplication tensor $\mathcal{T}_2$ and its decomposition under Strassen's algorithm ($R = 7$):

      Matrix multiplication tensor and algorithms, from Fig.1 of Fawzi et al., 2022.

      Please note that the (column) vectors $u^{(r)}, v^{(r)}, w^{(r)}$ are stacked horizontally.

      • The elements $u_{ij}$ in $U$ are indicators for the presence of $a_i$ in $m_j$.
      • The elements $v_{ij}$ in $V$ are indicators for the presence of $b_i$ in $m_j$.
      • The elements $w_{ij}$ in $W$ are indicators for the presence of $m_j$ in $c_i$.
  2. Formulate the tensor decomposition problem as a game named TensorGame.

    • State: $\mathcal{S}_0 = \mathcal{T}_n$, $\mathcal{S}_t = \mathcal{S}_{t-1} - u^{(t)} \otimes v^{(t)} \otimes w^{(t)}$
    • Action: $A_t = (u^{(t)}, v^{(t)}, w^{(t)})$ in canonical form, with coefficients in $F$ (for example, $F = \{-2, -1, 0, 1, 2\}$)
    • Termination condition: when $\mathcal{S}_t = 0$ or $t = T$.
    • Reward: $R_t = -1$ for each step $t \le T$, plus a terminal reward $R_T = -r(\mathcal{S}_T)$ if the step limit is reached with $\mathcal{S}_T \neq 0$, where $r(\mathcal{S}_T)$ is an upper bound on the rank of the terminal tensor. The reward function may be modified to optimize practical runtime instead of rank.
  3. Solve the TensorGame with DRL and MCTS.

    • Network Architecture:
      • Torso: Modified transformer architecture.
      • Policy head: Transformer architecture.
      • Value head: 4-layer MLP that outputs estimates of certain quantiles of the return distribution.
    • Demonstration Data: Randomly sample factors $\{(u^{(r)}, v^{(r)}, w^{(r)})\}_{r=1}^{R}$, sum the outer products to construct a tensor $\mathcal{D}$, and imitate these synthetic demonstrations.
    • Objective: Quantile regression for the value network $z$, and KL-divergence minimization for the policy network $\pi$ (against the synthetic demonstrations or the sample-based improved policy from MCTS).
    • Input Augmentation: Apply a random change of basis to $\mathcal{T}_n$ before the game (i.e., define the matrix multiplication tensor in another basis). Apply a random (signed permutation) change of basis in each new MCTS node.

      Overview of AlphaTensor, from Fig.2 of Fawzi et al., 2022.

    Please refer to the paper for figures of the architecture.
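The tensor construction in step 1 can be sketched in a few lines of NumPy. This is a minimal illustration; the helper name `matmul_tensor` is mine, not from the paper:

```python
import numpy as np

def matmul_tensor(n):
    """Matrix multiplication tensor T_n of shape (n^2, n^2, n^2).

    T[i, j, k] = 1 iff the product a_i * b_j contributes to c_k, where
    a, b, c are row-major flattenings of A, B, and C = A @ B.
    """
    T = np.zeros((n * n, n * n, n * n), dtype=int)
    for p in range(n):          # row of C
        for q in range(n):      # column of C
            for m in range(n):  # summation index of the naive algorithm
                T[p * n + m, m * n + q, p * n + q] = 1
    return T

T2 = matmul_tensor(2)
print(T2.sum())  # 8: one nonzero entry per multiplication in the naive O(n^3) algorithm
```

The $n^3$ ones in $\mathcal{T}_n$ correspond exactly to the scalar multiplications $a_{pm} b_{mq}$ of the naive algorithm, which is why the trivial decomposition has rank $n^3$.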
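As a concrete check of the decomposition-to-algorithm correspondence, here is Strassen's $R = 7$ factorization of $\mathcal{T}_2$ executed as Algorithm 1. A sketch: the factor matrices below are transcribed from Strassen's classical algorithm, not copied from the paper's figure:

```python
import numpy as np

# Columns of U, V, W are the factors u^(r), v^(r), w^(r), r = 1..7,
# using row-major flattening a = [a11, a12, a21, a22] (likewise b, c).
U = np.array([[1, 0, 1, 0, 1, -1, 0],
              [0, 0, 0, 0, 1,  0, 1],
              [0, 1, 0, 0, 0,  1, 0],
              [1, 1, 0, 1, 0,  0, -1]])
V = np.array([[1, 1, 0, -1, 0, 1, 0],
              [0, 0, 1,  0, 0, 1, 0],
              [0, 0, 0,  1, 0, 0, 1],
              [1, 0, -1, 0, 1, 0, 1]])
W = np.array([[1,  0, 0, 1, -1, 0, 1],
              [0,  0, 1, 0,  1, 0, 0],
              [0,  1, 0, 1,  0, 0, 0],
              [1, -1, 1, 0,  0, 1, 0]])

# The sum of the 7 rank-one tensors u^(r) ⊗ v^(r) ⊗ w^(r) reconstructs T_2.
T2 = np.einsum('ir,jr,kr->ijk', U, V, W)

# Algorithm 1: R scalar multiplications m_r = (u^(r)·a)(v^(r)·b), then c = W m.
rng = np.random.default_rng(0)
A, B = rng.standard_normal((2, 2)), rng.standard_normal((2, 2))
m = (U.T @ A.ravel()) * (V.T @ B.ravel())  # 7 multiplications instead of 8
C = (W @ m).reshape(2, 2)
print(np.allclose(C, A @ B))  # True
```

Because the identity holds without assuming commutativity of the entries, the same factors apply blockwise, which is what gives the recursive $O(n^{\log_2 7})$ algorithm.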
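The TensorGame dynamics in step 2 amount to subtracting one rank-one tensor per move until the state is zero or the step limit is hit. A minimal rollout sketch, with the reward of $-1$ per move; the names `matmul_tensor` and `play` are illustrative, not from the paper:

```python
import numpy as np

def matmul_tensor(n):
    # T[i, j, k] = 1 iff a_i * b_j contributes to c_k (row-major flattening).
    T = np.zeros((n * n,) * 3, dtype=int)
    for p in range(n):
        for q in range(n):
            for m in range(n):
                T[p * n + m, m * n + q, p * n + q] = 1
    return T

def play(S0, actions, T_limit=8):
    """Roll out TensorGame: each action (u, v, w) subtracts u ⊗ v ⊗ w.

    Returns (final state, total reward); -1 per move, stopping when the
    state reaches zero or the step limit is hit.
    """
    S, reward = S0.copy(), 0
    for t, (u, v, w) in enumerate(actions, start=1):
        S = S - np.einsum('i,j,k->ijk', u, v, w)
        reward -= 1
        if not S.any() or t == T_limit:
            break
    return S, reward

# The 8 moves of the trivial (naive) decomposition of T_2: each move removes
# the rank-one tensor for one scalar multiplication a_{pm} b_{mq} -> c_{pq}.
n = 2
I = np.eye(n * n, dtype=int)
actions = [(I[p * n + m], I[m * n + q], I[p * n + q])
           for p in range(n) for q in range(n) for m in range(n)]
S, r = play(matmul_tensor(n), actions)
print(S.any(), r)  # False -8: the game ends at the zero tensor after 8 moves
```

A rank-7 sequence of actions (Strassen's factors) would finish the same game with reward $-7$, which is exactly why maximizing reward minimizes the rank.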
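The signed-permutation change of basis used for augmentation can be sketched as follows. Rank is preserved because the basis-change matrices are invertible, so a low-rank decomposition found in the transformed basis maps back to one for $\mathcal{T}_n$; the helper names here are illustrative:

```python
import numpy as np

def signed_permutation(d, rng):
    # Random signed permutation matrix: exactly one +/-1 per row and column.
    P = np.zeros((d, d), dtype=int)
    P[np.arange(d), rng.permutation(d)] = rng.choice([-1, 1], size=d)
    return P

def change_of_basis(T, A, B, C):
    # Transform a 3D tensor along each of its three modes:
    # T'[i, j, k] = sum_{a,b,c} T[a, b, c] * A[a, i] * B[b, j] * C[c, k]
    return np.einsum('abc,ai,bj,ck->ijk', T, A, B, C)

rng = np.random.default_rng(0)
d = 4  # n^2 for n = 2
T = rng.integers(0, 2, size=(d, d, d))
A, B, C = (signed_permutation(d, rng) for _ in range(3))
T_aug = change_of_basis(T, A, B, C)

# A signed permutation only relocates entries and flips signs, so the
# multiset of absolute values is unchanged.
print(np.array_equal(np.sort(np.abs(T_aug).ravel()),
                     np.sort(np.abs(T).ravel())))  # True
```

In AlphaTensor this augmentation effectively turns one game into many equivalent ones, diversifying the states the network sees.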

Results and Costs

  • Results
    • SOTA matrix multiplication algorithms for certain matrix sizes.
    • SOTA algorithm for multiplying $n \times n$ skew-symmetric matrices.
    • Rediscovery of the Fourier basis.
  • Training Costs
    • TPU v3 and TPU v4 hardware; training takes a week to converge.

Official Resources

  • [Nature 2022] Discovering faster matrix multiplication algorithms with reinforcement learning [paper][blog][code] (citations: 98, as of 2023-04-28)

Community Resources
