$$ \def\bm#1{\boldsymbol{#1}} %%%%% NEW MATH DEFINITIONS %%%%% % % Mark sections of captions for referring to divisions of figures % \newcommand{\figleft}{{\em (Left)}} % \newcommand{\figcenter}{{\em (Center)}} % \newcommand{\figright}{{\em (Right)}} % \newcommand{\figtop}{{\em (Top)}} % \newcommand{\figbottom}{{\em (Bottom)}} % \newcommand{\captiona}{{\em (a)}} % \newcommand{\captionb}{{\em (b)}} % \newcommand{\captionc}{{\em (c)}} % \newcommand{\captiond}{{\em (d)}} % Highlight a newly defined term \newcommand{\newterm}[1]{{\bf #1}} % % Figure reference, lower-case. % \def\figref#1{figure~\ref{#1}} % % Figure reference, capital. For start of sentence % \def\Figref#1{Figure~\ref{#1}} % \def\twofigref#1#2{figures \ref{#1} and \ref{#2}} % \def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}} % % Section reference, lower-case. % \def\secref#1{section~\ref{#1}} % % Section reference, capital. % \def\Secref#1{Section~\ref{#1}} % % Reference to two sections. % \def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}} % % Reference to three sections. % \def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}} % % Reference to an equation, lower-case. % \def\eqref#1{equation~\ref{#1}} % % Reference to an equation, upper case % \def\Eqref#1{Equation~\ref{#1}} % % A raw reference to an equation---avoid using if possible % \def\plaineqref#1{\ref{#1}} % % Reference to a chapter, lower-case. % \def\chapref#1{chapter~\ref{#1}} % % Reference to an equation, upper case. % \def\Chapref#1{Chapter~\ref{#1}} % % Reference to a range of chapters % \def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}} % % Reference to an algorithm, lower-case. % \def\algref#1{algorithm~\ref{#1}} % % Reference to an algorithm, upper case. 
% \def\Algref#1{Algorithm~\ref{#1}} % \def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}} % \def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}} % % Reference to a part, lower case % \def\partref#1{part~\ref{#1}} % % Reference to a part, upper case % \def\Partref#1{Part~\ref{#1}} % \def\twopartref#1#2{parts \ref{#1} and \ref{#2}} \def\ceil#1{\lceil #1 \rceil} \def\floor#1{\lfloor #1 \rfloor} \def\1{\bm{1}} \newcommand{\train}{\mathcal{D}} \newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}} \newcommand{\test}{\mathcal{D_{\mathrm{test}}}} \def\eps{{\epsilon}} % Random variables \def\reta{{\textnormal{$\eta$}}} \def\ra{{\textnormal{a}}} \def\rb{{\textnormal{b}}} \def\rc{{\textnormal{c}}} \def\rd{{\textnormal{d}}} \def\re{{\textnormal{e}}} \def\rf{{\textnormal{f}}} \def\rg{{\textnormal{g}}} \def\rh{{\textnormal{h}}} \def\ri{{\textnormal{i}}} \def\rj{{\textnormal{j}}} \def\rk{{\textnormal{k}}} \def\rl{{\textnormal{l}}} % rm is already a command, just don't name any random variables m \def\rn{{\textnormal{n}}} \def\ro{{\textnormal{o}}} \def\rp{{\textnormal{p}}} \def\rq{{\textnormal{q}}} \def\rr{{\textnormal{r}}} \def\rs{{\textnormal{s}}} \def\rt{{\textnormal{t}}} \def\ru{{\textnormal{u}}} \def\rv{{\textnormal{v}}} \def\rw{{\textnormal{w}}} \def\rx{{\textnormal{x}}} \def\ry{{\textnormal{y}}} \def\rz{{\textnormal{z}}} % Random vectors \def\rvepsilon{{\mathbf{\epsilon}}} \def\rvtheta{{\mathbf{\theta}}} \def\rva{{\mathbf{a}}} \def\rvb{{\mathbf{b}}} \def\rvc{{\mathbf{c}}} \def\rvd{{\mathbf{d}}} \def\rve{{\mathbf{e}}} \def\rvf{{\mathbf{f}}} \def\rvg{{\mathbf{g}}} \def\rvh{{\mathbf{h}}} \def\rvi{{\mathbf{i}}} \def\rvj{{\mathbf{j}}} \def\rvk{{\mathbf{k}}} \def\rvl{{\mathbf{l}}} \def\rvm{{\mathbf{m}}} \def\rvn{{\mathbf{n}}} \def\rvo{{\mathbf{o}}} \def\rvp{{\mathbf{p}}} \def\rvq{{\mathbf{q}}} \def\rvr{{\mathbf{r}}} \def\rvs{{\mathbf{s}}} \def\rvt{{\mathbf{t}}} \def\rvu{{\mathbf{u}}} \def\rvv{{\mathbf{v}}} \def\rvw{{\mathbf{w}}} \def\rvx{{\mathbf{x}}} \def\rvy{{\mathbf{y}}} 
\def\rvz{{\mathbf{z}}} % Elements of random vectors \def\erva{{\textnormal{a}}} \def\ervb{{\textnormal{b}}} \def\ervc{{\textnormal{c}}} \def\ervd{{\textnormal{d}}} \def\erve{{\textnormal{e}}} \def\ervf{{\textnormal{f}}} \def\ervg{{\textnormal{g}}} \def\ervh{{\textnormal{h}}} \def\ervi{{\textnormal{i}}} \def\ervj{{\textnormal{j}}} \def\ervk{{\textnormal{k}}} \def\ervl{{\textnormal{l}}} \def\ervm{{\textnormal{m}}} \def\ervn{{\textnormal{n}}} \def\ervo{{\textnormal{o}}} \def\ervp{{\textnormal{p}}} \def\ervq{{\textnormal{q}}} \def\ervr{{\textnormal{r}}} \def\ervs{{\textnormal{s}}} \def\ervt{{\textnormal{t}}} \def\ervu{{\textnormal{u}}} \def\ervv{{\textnormal{v}}} \def\ervw{{\textnormal{w}}} \def\ervx{{\textnormal{x}}} \def\ervy{{\textnormal{y}}} \def\ervz{{\textnormal{z}}} % Random matrices \def\rmA{{\mathbf{A}}} \def\rmB{{\mathbf{B}}} \def\rmC{{\mathbf{C}}} \def\rmD{{\mathbf{D}}} \def\rmE{{\mathbf{E}}} \def\rmF{{\mathbf{F}}} \def\rmG{{\mathbf{G}}} \def\rmH{{\mathbf{H}}} \def\rmI{{\mathbf{I}}} \def\rmJ{{\mathbf{J}}} \def\rmK{{\mathbf{K}}} \def\rmL{{\mathbf{L}}} \def\rmM{{\mathbf{M}}} \def\rmN{{\mathbf{N}}} \def\rmO{{\mathbf{O}}} \def\rmP{{\mathbf{P}}} \def\rmQ{{\mathbf{Q}}} \def\rmR{{\mathbf{R}}} \def\rmS{{\mathbf{S}}} \def\rmT{{\mathbf{T}}} \def\rmU{{\mathbf{U}}} \def\rmV{{\mathbf{V}}} \def\rmW{{\mathbf{W}}} \def\rmX{{\mathbf{X}}} \def\rmY{{\mathbf{Y}}} \def\rmZ{{\mathbf{Z}}} % Elements of random matrices \def\ermA{{\textnormal{A}}} \def\ermB{{\textnormal{B}}} \def\ermC{{\textnormal{C}}} \def\ermD{{\textnormal{D}}} \def\ermE{{\textnormal{E}}} \def\ermF{{\textnormal{F}}} \def\ermG{{\textnormal{G}}} \def\ermH{{\textnormal{H}}} \def\ermI{{\textnormal{I}}} \def\ermJ{{\textnormal{J}}} \def\ermK{{\textnormal{K}}} \def\ermL{{\textnormal{L}}} \def\ermM{{\textnormal{M}}} \def\ermN{{\textnormal{N}}} \def\ermO{{\textnormal{O}}} \def\ermP{{\textnormal{P}}} \def\ermQ{{\textnormal{Q}}} \def\ermR{{\textnormal{R}}} \def\ermS{{\textnormal{S}}} \def\ermT{{\textnormal{T}}} 
\def\ermU{{\textnormal{U}}} \def\ermV{{\textnormal{V}}} \def\ermW{{\textnormal{W}}} \def\ermX{{\textnormal{X}}} \def\ermY{{\textnormal{Y}}} \def\ermZ{{\textnormal{Z}}} % Vectors \def\vzero{{\bm{0}}} \def\vone{{\bm{1}}} \def\vmu{{\bm{\mu}}} \def\vtheta{{\bm{\theta}}} \def\va{{\bm{a}}} \def\vb{{\bm{b}}} \def\vc{{\bm{c}}} \def\vd{{\bm{d}}} \def\ve{{\bm{e}}} \def\vf{{\bm{f}}} \def\vg{{\bm{g}}} \def\vh{{\bm{h}}} \def\vi{{\bm{i}}} \def\vj{{\bm{j}}} \def\vk{{\bm{k}}} \def\vl{{\bm{l}}} \def\vm{{\bm{m}}} \def\vn{{\bm{n}}} \def\vo{{\bm{o}}} \def\vp{{\bm{p}}} \def\vq{{\bm{q}}} \def\vr{{\bm{r}}} \def\vs{{\bm{s}}} \def\vt{{\bm{t}}} \def\vu{{\bm{u}}} \def\vv{{\bm{v}}} \def\vw{{\bm{w}}} \def\vx{{\bm{x}}} \def\vy{{\bm{y}}} \def\vz{{\bm{z}}} % Elements of vectors \def\evalpha{{\alpha}} \def\evbeta{{\beta}} \def\evepsilon{{\epsilon}} \def\evlambda{{\lambda}} \def\evomega{{\omega}} \def\evmu{{\mu}} \def\evpsi{{\psi}} \def\evsigma{{\sigma}} \def\evtheta{{\theta}} \def\eva{{a}} \def\evb{{b}} \def\evc{{c}} \def\evd{{d}} \def\eve{{e}} \def\evf{{f}} \def\evg{{g}} \def\evh{{h}} \def\evi{{i}} \def\evj{{j}} \def\evk{{k}} \def\evl{{l}} \def\evm{{m}} \def\evn{{n}} \def\evo{{o}} \def\evp{{p}} \def\evq{{q}} \def\evr{{r}} \def\evs{{s}} \def\evt{{t}} \def\evu{{u}} \def\evv{{v}} \def\evw{{w}} \def\evx{{x}} \def\evy{{y}} \def\evz{{z}} % Matrix \def\mA{{\bm{A}}} \def\mB{{\bm{B}}} \def\mC{{\bm{C}}} \def\mD{{\bm{D}}} \def\mE{{\bm{E}}} \def\mF{{\bm{F}}} \def\mG{{\bm{G}}} \def\mH{{\bm{H}}} \def\mI{{\bm{I}}} \def\mJ{{\bm{J}}} \def\mK{{\bm{K}}} \def\mL{{\bm{L}}} \def\mM{{\bm{M}}} \def\mN{{\bm{N}}} \def\mO{{\bm{O}}} \def\mP{{\bm{P}}} \def\mQ{{\bm{Q}}} \def\mR{{\bm{R}}} \def\mS{{\bm{S}}} \def\mT{{\bm{T}}} \def\mU{{\bm{U}}} \def\mV{{\bm{V}}} \def\mW{{\bm{W}}} \def\mX{{\bm{X}}} \def\mY{{\bm{Y}}} \def\mZ{{\bm{Z}}} \def\mBeta{{\bm{\beta}}} \def\mPhi{{\bm{\Phi}}} \def\mLambda{{\bm{\Lambda}}} \def\mSigma{{\bm{\Sigma}}} % Tensor \newcommand{\tens}[1]{\mathsf{#1}} \def\tA{{\tens{A}}} \def\tB{{\tens{B}}} 
\def\tC{{\tens{C}}} \def\tD{{\tens{D}}} \def\tE{{\tens{E}}} \def\tF{{\tens{F}}} \def\tG{{\tens{G}}} \def\tH{{\tens{H}}} \def\tI{{\tens{I}}} \def\tJ{{\tens{J}}} \def\tK{{\tens{K}}} \def\tL{{\tens{L}}} \def\tM{{\tens{M}}} \def\tN{{\tens{N}}} \def\tO{{\tens{O}}} \def\tP{{\tens{P}}} \def\tQ{{\tens{Q}}} \def\tR{{\tens{R}}} \def\tS{{\tens{S}}} \def\tT{{\tens{T}}} \def\tU{{\tens{U}}} \def\tV{{\tens{V}}} \def\tW{{\tens{W}}} \def\tX{{\tens{X}}} \def\tY{{\tens{Y}}} \def\tZ{{\tens{Z}}} % Graph \def\gA{{\mathcal{A}}} \def\gB{{\mathcal{B}}} \def\gC{{\mathcal{C}}} \def\gD{{\mathcal{D}}} \def\gE{{\mathcal{E}}} \def\gF{{\mathcal{F}}} \def\gG{{\mathcal{G}}} \def\gH{{\mathcal{H}}} \def\gI{{\mathcal{I}}} \def\gJ{{\mathcal{J}}} \def\gK{{\mathcal{K}}} \def\gL{{\mathcal{L}}} \def\gM{{\mathcal{M}}} \def\gN{{\mathcal{N}}} \def\gO{{\mathcal{O}}} \def\gP{{\mathcal{P}}} \def\gQ{{\mathcal{Q}}} \def\gR{{\mathcal{R}}} \def\gS{{\mathcal{S}}} \def\gT{{\mathcal{T}}} \def\gU{{\mathcal{U}}} \def\gV{{\mathcal{V}}} \def\gW{{\mathcal{W}}} \def\gX{{\mathcal{X}}} \def\gY{{\mathcal{Y}}} \def\gZ{{\mathcal{Z}}} % Sets \def\sA{{\mathbb{A}}} \def\sB{{\mathbb{B}}} \def\sC{{\mathbb{C}}} \def\sD{{\mathbb{D}}} % Don't use a set called E, because this would be the same as our symbol % for expectation. 
\def\sF{{\mathbb{F}}} \def\sG{{\mathbb{G}}} \def\sH{{\mathbb{H}}} \def\sI{{\mathbb{I}}} \def\sJ{{\mathbb{J}}} \def\sK{{\mathbb{K}}} \def\sL{{\mathbb{L}}} \def\sM{{\mathbb{M}}} \def\sN{{\mathbb{N}}} \def\sO{{\mathbb{O}}} \def\sP{{\mathbb{P}}} \def\sQ{{\mathbb{Q}}} \def\sR{{\mathbb{R}}} \def\sS{{\mathbb{S}}} \def\sT{{\mathbb{T}}} \def\sU{{\mathbb{U}}} \def\sV{{\mathbb{V}}} \def\sW{{\mathbb{W}}} \def\sX{{\mathbb{X}}} \def\sY{{\mathbb{Y}}} \def\sZ{{\mathbb{Z}}} % Entries of a matrix \def\emLambda{{\Lambda}} \def\emA{{A}} \def\emB{{B}} \def\emC{{C}} \def\emD{{D}} \def\emE{{E}} \def\emF{{F}} \def\emG{{G}} \def\emH{{H}} \def\emI{{I}} \def\emJ{{J}} \def\emK{{K}} \def\emL{{L}} \def\emM{{M}} \def\emN{{N}} \def\emO{{O}} \def\emP{{P}} \def\emQ{{Q}} \def\emR{{R}} \def\emS{{S}} \def\emT{{T}} \def\emU{{U}} \def\emV{{V}} \def\emW{{W}} \def\emX{{X}} \def\emY{{Y}} \def\emZ{{Z}} \def\emSigma{{\Sigma}} % entries of a tensor % Same font as tensor, without \bm wrapper \newcommand{\etens}[1]{\mathsfit{#1}} \def\etLambda{{\etens{\Lambda}}} \def\etA{{\etens{A}}} \def\etB{{\etens{B}}} \def\etC{{\etens{C}}} \def\etD{{\etens{D}}} \def\etE{{\etens{E}}} \def\etF{{\etens{F}}} \def\etG{{\etens{G}}} \def\etH{{\etens{H}}} \def\etI{{\etens{I}}} \def\etJ{{\etens{J}}} \def\etK{{\etens{K}}} \def\etL{{\etens{L}}} \def\etM{{\etens{M}}} \def\etN{{\etens{N}}} \def\etO{{\etens{O}}} \def\etP{{\etens{P}}} \def\etQ{{\etens{Q}}} \def\etR{{\etens{R}}} \def\etS{{\etens{S}}} \def\etT{{\etens{T}}} \def\etU{{\etens{U}}} \def\etV{{\etens{V}}} \def\etW{{\etens{W}}} \def\etX{{\etens{X}}} \def\etY{{\etens{Y}}} \def\etZ{{\etens{Z}}} % The true underlying data generating distribution \newcommand{\pdata}{p_{\rm{data}}} % The empirical distribution defined by the training set \newcommand{\ptrain}{\hat{p}_{\rm{data}}} \newcommand{\Ptrain}{\hat{P}_{\rm{data}}} % The model distribution \newcommand{\pmodel}{p_{\rm{model}}} \newcommand{\Pmodel}{P_{\rm{model}}} \newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}} % Stochastic 
autoencoder distributions \newcommand{\pencode}{p_{\rm{encoder}}} \newcommand{\pdecode}{p_{\rm{decoder}}} \newcommand{\precons}{p_{\rm{reconstruct}}} \newcommand{\laplace}{\mathrm{Laplace}} % Laplace distribution \newcommand{\E}{\mathbb{E}} \newcommand{\Ls}{\mathcal{L}} \newcommand{\R}{\mathbb{R}} \newcommand{\emp}{\tilde{p}} \newcommand{\lr}{\alpha} \newcommand{\reg}{\lambda} \newcommand{\rect}{\mathrm{rectifier}} \newcommand{\softmax}{\mathrm{softmax}} \newcommand{\sigmoid}{\sigma} \newcommand{\softplus}{\zeta} \newcommand{\KL}{D_{\mathrm{KL}}} \newcommand{\Var}{\mathrm{Var}} \newcommand{\standarderror}{\mathrm{SE}} \newcommand{\Cov}{\mathrm{Cov}} % Wolfram Mathworld says $L^2$ is for function spaces and $\ell^2$ is for vectors % But then they seem to use $L^2$ for vectors throughout the site, and so does % wikipedia. \newcommand{\normlzero}{L^0} \newcommand{\normlone}{L^1} \newcommand{\normltwo}{L^2} \newcommand{\normlp}{L^p} \newcommand{\normmax}{L^\infty} \newcommand{\parents}{Pa} % See usage in notation.tex. Chosen to match Daphne's book. \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\Tr}{Tr} \let\ab\allowbreak $$

AlphaTensor

Goal: Discover faster (matrix multiplication) algorithms with RL.

Contribution: Discovers state-of-the-art (SOTA) algorithms for several matrix multiplication problems.

Concept

Overview of AlphaTensor, from DeepMind Blog.

Single-player game played by AlphaTensor, where the goal is to find a correct matrix multiplication algorithm. The state of the game is a cubic array of numbers (shown as grey for 0, blue for 1, and green for -1), representing the remaining work to be done.

  1. Formulate the matrix multiplication problem as a tensor decomposition problem.
    This is discussed in previous works such as (Lim 2021, p.67, Example 3.9).

    • Represent a matrix multiplication algorithm as a tensor.

      The \(n\times n\) matrix multiplication operation (independent of the matrices being multiplied) can be represented by a fixed 3D (\(n^2\times n^2\times n^2\)) tensor \(\tT_n\) (with entries in \(\{0,1\}\)) called the matrix multiplication tensor (or Strassen's tensor).

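      \(\tT_n\) can be read off directly from the standard \(O(n^3)\) matrix multiplication algorithm. With the entries of the two inputs and the output flattened row by row into \(a\), \(b\) and \(c\), each output entry is a sum of \(n\) products, \(c_{(i-1)n+j}=\sum_{k=1}^{n}a_{(i-1)n+k}\,b_{(k-1)n+j}\), and the entry of \(\tT_n\) at position \((p,q,s)\) is 1 exactly when the product \(a_pb_q\) appears in \(c_s\).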
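The construction above can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the paper: the function name `matmul_tensor` and the 0-indexed row-major flattening are assumptions made here.

```python
import numpy as np


def matmul_tensor(n: int) -> np.ndarray:
    """Build the n^2 x n^2 x n^2 matrix multiplication tensor T_n.

    Hypothetical helper: entries of A, B and C are flattened row-major
    and indexed from 0, so a[i, k] sits at position i * n + k.
    """
    T = np.zeros((n * n, n * n, n * n), dtype=np.int64)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                # The product a[i, k] * b[k, j] contributes to c[i, j].
                T[i * n + k, k * n + j, i * n + j] = 1
    return T


T2 = matmul_tensor(2)

# Contracting T_n with the flattened inputs recovers the flattened product.
rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(2, 2))
B = rng.integers(-3, 4, size=(2, 2))
c = np.einsum("pqs,p,q->s", T2, A.ravel(), B.ravel())
assert np.array_equal(c.reshape(2, 2), A @ B)
```

The tensor has exactly \(n^3\) nonzero entries, one per scalar product of the standard algorithm.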

    • Execute the matrix multiplication algorithm based on the decomposition of the tensor.

      \(\tT_n=\sum^R_{r=1}\vu^{(r)}\otimes\vv^{(r)}\otimes\vw^{(r)}\), where \(\vu^{(r)},\vv^{(r)},\vw^{(r)}\in\R^{n^2}\) are vectors and \(\otimes\) denotes the outer (tensor) product. The outer product of each triplet \((\vu^{(r)},\vv^{(r)},\vw^{(r)})\) is a rank-one tensor, so a decomposition into \(R\) such terms shows that \(\mathrm{Rank}(\tT_n)\le R\). Solving the tensor decomposition problem yields an algorithm for multiplying arbitrary \(n\times n\) matrices with \(R\) scalar multiplications; applied recursively to block matrices, it gives an \(O(N^{\log_n(R)})\) algorithm for \(N\times N\) matrices:

      from Algorithm 1 of Fawzi et al., 2022.
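As a rough sketch of how a decomposition turns into a procedure, the snippet below computes the \(R\) products \(m_r\) from linear combinations of the inputs and then combines them into the output. The helper names `naive_factors` and `multiply_with_factors` are made up for this illustration, and the factors used are the trivial rank-\(n^3\) decomposition from the standard algorithm, not a discovered one.

```python
import itertools

import numpy as np


def naive_factors(n: int):
    """Trivial decomposition of T_n with R = n^3 (one product per column)."""
    R = n ** 3
    U = np.zeros((n * n, R), dtype=np.int64)
    V = np.zeros((n * n, R), dtype=np.int64)
    W = np.zeros((n * n, R), dtype=np.int64)
    for r, (i, j, k) in enumerate(itertools.product(range(n), repeat=3)):
        U[i * n + k, r] = 1  # m_r reads a[i, k]
        V[k * n + j, r] = 1  # m_r reads b[k, j]
        W[i * n + j, r] = 1  # m_r is added to c[i, j]
    return U, V, W


def multiply_with_factors(U, V, W, A, B):
    """Multiply A and B using factor matrices whose columns are u, v, w."""
    a, b = A.ravel(), B.ravel()
    m = (U.T @ a) * (V.T @ b)  # the R scalar multiplications
    c = W @ m                  # linear combinations of the products
    return c.reshape(A.shape)


n = 3
U, V, W = naive_factors(n)
rng = np.random.default_rng(1)
A = rng.integers(-5, 6, size=(n, n))
B = rng.integers(-5, 6, size=(n, n))
C = multiply_with_factors(U, V, W, A, B)
assert np.array_equal(C, A @ B)
```

Any other valid decomposition can be plugged into `multiply_with_factors` unchanged; only the number of columns \(R\), and hence the number of scalar multiplications, differs.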

      Example: The tensor \(\tT_2\) and its decomposition given by Strassen's algorithm (\(R=7\)):

      Matrix multiplication tensor and algorithms, from Fig.1 of Fawzi et al., 2022.

      Note that the (column) vectors \(\vu^{(r)},\vv^{(r)},\vw^{(r)}\) are stacked horizontally to form the matrices \(\mU,\mV,\mW\), so column \(j\) corresponds to the product \(m_j\).

      • The element \(u_{ij}\) of \(\mU\) is the coefficient of \(a_i\) in the first factor of \(m_j\).
      • The element \(v_{ij}\) of \(\mV\) is the coefficient of \(b_i\) in the second factor of \(m_j\).
      • The element \(w_{ij}\) of \(\mW\) is the coefficient of \(m_j\) in \(c_i\).
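The reading of the factor matrices given above can be checked numerically for Strassen's algorithm. The matrices below are written out by hand from the textbook form of Strassen's seven products, assuming row-major ordering of \(a_1,\dots,a_4\), \(b_1,\dots,b_4\), \(c_1,\dots,c_4\); the row and column ordering in the paper's figure may differ.

```python
import itertools
import math

import numpy as np

# Columns are m_1..m_7; rows are a_1..a_4, b_1..b_4 and c_1..c_4.
U = np.array([[1, 0, 1, 0, 1, -1, 0],
              [0, 0, 0, 0, 1, 0, 1],
              [0, 1, 0, 0, 0, 1, 0],
              [1, 1, 0, 1, 0, 0, -1]])
V = np.array([[1, 1, 0, -1, 0, 1, 0],
              [0, 0, 1, 0, 0, 1, 0],
              [0, 0, 0, 1, 0, 0, 1],
              [1, 0, -1, 0, 1, 0, 1]])
W = np.array([[1, 0, 0, 1, -1, 0, 1],
              [0, 0, 1, 0, 1, 0, 0],
              [0, 1, 0, 1, 0, 0, 0],
              [1, -1, 1, 0, 0, 1, 0]])

# The 2x2 matrix multiplication tensor (0-indexed, row-major flattening).
T2 = np.zeros((4, 4, 4), dtype=np.int64)
for i, j, k in itertools.product(range(2), repeat=3):
    T2[2 * i + k, 2 * k + j, 2 * i + j] = 1

# Summing the seven rank-one terms u (x) v (x) w reproduces T_2.
reconstructed = np.einsum("pr,qr,sr->pqs", U, V, W)
assert np.array_equal(reconstructed, T2)

# Exponent obtained by applying the rank-7 algorithm recursively.
exponent = math.log(7, 2)  # about 2.807, versus 3 for the standard algorithm
```

For instance, the first column of `U` is nonzero in rows 1 and 4 and the first column of `V` likewise, which encodes \(m_1=(a_1+a_4)(b_1+b_4)\).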
  2. Formulate the tensor decomposition problem as a game named TensorGame.

    • State: \(S_0=\tT_n, S_t=S_{t-1} - \vu^{(t)}\otimes\vv^{(t)}\otimes\vw^{(t)}\)
    • Action: \(A_t=(\vu^{(t)},\vv^{(t)},\vw^{(t)})\) in canonical form, with coefficients in \(\sF\) (for example \(\sF=\{-2,-1,0,1,2\}\))
    • Termination condition: when \(S_t=\bm{0}\) or \(t=T\).
    • Reward: \(R_t=-1\) if \(t\ne T\), \(R_T=-r(S_T)\), where \(r(S_T)\) is an upper bound on the rank of the terminal tensor. The reward function may be modified to optimize practical runtime instead of rank.
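The game rules listed above can be sketched as a tiny environment. This is an illustrative toy, not AlphaTensor's implementation: the class name `TensorGame`, its interface, and the choice of \(r(S_T)\) as the number of nonzero entries (a valid but loose upper bound on rank, since each nonzero entry is itself a rank-one term) are all assumptions. The reward follows the formula as written in these notes.

```python
import itertools

import numpy as np


class TensorGame:
    """Toy single-player game: subtract rank-one tensors until zero."""

    def __init__(self, target: np.ndarray, max_steps: int):
        self.state = target.copy()  # S_0 = T_n
        self.max_steps = max_steps  # step limit T
        self.t = 0
        self.done = False

    def step(self, u, v, w):
        if self.done:
            raise RuntimeError("game is over")
        self.t += 1
        # S_t = S_{t-1} - u (x) v (x) w
        self.state = self.state - np.einsum("p,q,s->pqs", u, v, w)
        if self.t == self.max_steps:
            # Assumed rank bound r(S_T): count of nonzero entries.
            reward = -int(np.count_nonzero(self.state))
        else:
            reward = -1
        self.done = (not self.state.any()) or self.t == self.max_steps
        return self.state, reward, self.done


# Play the 8 moves of the standard algorithm on T_2.
n = 2
T2 = np.zeros((n * n, n * n, n * n), dtype=np.int64)
for i, j, k in itertools.product(range(n), repeat=3):
    T2[i * n + k, k * n + j, i * n + j] = 1

game = TensorGame(T2, max_steps=12)
eye = np.eye(n * n, dtype=np.int64)
total_reward = 0
for i, j, k in itertools.product(range(n), repeat=3):
    _, reward, done = game.step(eye[i * n + k], eye[k * n + j], eye[i * n + j])
    total_reward += reward
```

With this reward, a shorter sequence of moves that reaches the zero tensor earns a higher return, so maximizing return corresponds to minimizing the number of rank-one terms.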
  3. Solve the TensorGame with DRL and MCTS.

    • Network Architecture:
      • Torso: Modified transformer architecture.
      • Policy head: Transformer architecture.
      • Value head: 4-layer MLP that outputs estimates of certain quantiles of the return.
    • Demonstration Data: Randomly sample factors \(\{(\vu^{(r)},\vv^{(r)},\vw^{(r)})\}^R_{r=1}\), sum their outer products to construct a tensor \(\tD\), and train the network to imitate these synthetic demonstrations.
    • Objective: Quantile regression for the value network \(z\), and KL-divergence minimization for the policy network \(\pi\) (against either the synthetic demonstrations or the sample-based improved policy from MCTS).
    • Input Augmentation: Apply a random change of basis to \(\tT_n\) before each game (i.e., define the matrix multiplication tensor in another basis). Apply a random (signed permutation) change of basis at each new MCTS node.
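The synthetic demonstration idea is that decomposing a tensor is hard, but building a tensor from known factors is easy. A minimal sketch follows, assuming uniform sampling of coefficients from \(\{-2,-1,0,1,2\}\); the sampling distribution actually used in the paper is not described in these notes, and `sample_demonstration` is a hypothetical name.

```python
import numpy as np


def sample_demonstration(size, rank, rng, coeffs=(-2, -1, 0, 1, 2)):
    """Sample `rank` random factor triplets and the tensor they sum to.

    Assumption: coefficients are drawn uniformly and independently.
    """
    U = rng.choice(coeffs, size=(size, rank))
    V = rng.choice(coeffs, size=(size, rank))
    W = rng.choice(coeffs, size=(size, rank))
    # D = sum_r u^(r) (x) v^(r) (x) w^(r)
    D = np.einsum("pr,qr,sr->pqs", U, V, W)
    return D, (U, V, W)


rng = np.random.default_rng(0)
D, (U, V, W) = sample_demonstration(size=4, rank=5, rng=rng)

# Replaying the sampled factors as moves drives the tensor to zero,
# so (D, factors) is a valid, if not necessarily optimal, game record.
residual = D.copy()
for r in range(U.shape[1]):
    residual = residual - np.einsum("p,q,s->pqs", U[:, r], V[:, r], W[:, r])
```

Each pair of a tensor and its generating factors can then serve as a supervised target for the policy.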

      Overview of AlphaTensor, from Fig.2 of Fawzi et al., 2022.

    Please refer to the paper for figures of the architecture.
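The change-of-basis augmentation mentioned under step 3 can be illustrated as follows. This is a sketch under assumptions: the three basis matrices are sampled independently here, which may differ from how the paper ties them together, and the helper names are hypothetical. It relies only on the fact that applying invertible matrices along each mode maps a decomposition of the original tensor to one of the transformed tensor with the same number of terms.

```python
import numpy as np


def random_signed_permutation(size, rng):
    """Permutation matrix with random +/-1 signs on its nonzero entries."""
    P = np.zeros((size, size), dtype=np.int64)
    P[np.arange(size), rng.permutation(size)] = rng.choice([-1, 1], size=size)
    return P


def change_basis(T, A, B, C):
    """Apply A, B, C along the first, second and third mode of T."""
    return np.einsum("ip,jq,ks,pqs->ijk", A, B, C, T)


rng = np.random.default_rng(0)
size, rank = 4, 3

# A small tensor with a known decomposition (columns of U, V, W).
U = rng.integers(-1, 2, size=(size, rank))
V = rng.integers(-1, 2, size=(size, rank))
W = rng.integers(-1, 2, size=(size, rank))
T = np.einsum("pr,qr,sr->pqs", U, V, W)

A, B, C = (random_signed_permutation(size, rng) for _ in range(3))
T_new = change_basis(T, A, B, C)

# The transformed factors decompose the transformed tensor.
T_from_factors = np.einsum("pr,qr,sr->pqs", A @ U, B @ V, C @ W)
assert np.array_equal(T_new, T_from_factors)
```

Signed permutations keep all entries within the original coefficient set, which is one plausible reason they are convenient inside the search.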

Results and Costs

  • Results
    • SOTA matrix multiplication algorithms for certain matrix sizes.
    • SOTA algorithm for multiplying an \(n\times n\) skew-symmetric matrix by a vector.
    • Re-discover the Fourier basis.
  • Training Costs
    • Training uses TPU v3 and TPU v4 hardware and takes a week to converge.

Official Resources

  • [Nature 2022] Discovering faster matrix multiplication algorithms with reinforcement learning [paper][blog][code] (citations:  98 78, as of 2023-04-28)

Community Resources
