Large-Scale Matrix Factorization on TPUs

Matrix factorization is one of the oldest, yet still widely used, techniques for learning how to recommend items such as songs or movies from user ratings. In its basic form, it approximates a large, sparse (i.e., mostly empty) matrix of user-item interactions with a product of two smaller, denser matrices representing learned item and user features. These dense matrices, in turn, can be used to recommend items to a user with which they have not interacted before.

Despite its algorithmic simplicity, matrix factorization can still achieve competitive performance in recommender benchmarks. Alternating least squares (ALS), and especially its implicit variation, is a fundamental algorithm for learning the parameters of matrix factorization. ALS is known for its high efficiency because it scales linearly in the number of rows, columns and non-zeros. Hence, this algorithm is very well suited to large-scale challenges. However, for very large real-world matrix factorization datasets, a single-machine implementation does not suffice, and a large distributed system is required. Most distributed implementations of matrix factorization that employ ALS leverage off-the-shelf CPU devices, and rightfully so, given the inherently sparse nature of the problem (the input matrix is mostly empty).
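To make the alternating structure concrete, here is a minimal NumPy sketch of ALS on a tiny matrix. It is not the ALX implementation: it fits only the observed entries with a plain L2 penalty (the explicit variant, not the implicit one), and the sizes, the `reg` value and the helper names `als_step` and `observed_loss` are assumptions chosen for illustration.

```python
import numpy as np

def als_step(Y, mask, fixed, reg):
    """Solve one regularized least-squares problem per row of Y, holding `fixed` constant."""
    dim = fixed.shape[1]
    solved = np.zeros((Y.shape[0], dim))
    for r in range(Y.shape[0]):
        cols = np.flatnonzero(mask[r])        # observed entries in this row
        F = fixed[cols]                       # embeddings of the observed columns
        A = F.T @ F + reg * np.eye(dim)       # normal equations for this row
        b = F.T @ Y[r, cols]
        solved[r] = np.linalg.solve(A, b)
    return solved

def observed_loss(Y, mask, W, H):
    """Squared error over the observed entries only."""
    return float(np.sum(mask * (Y - W @ H.T) ** 2))

rng = np.random.default_rng(0)
n_users, n_items, dim = 6, 5, 2               # hypothetical toy sizes
Y = rng.random((n_users, n_items))
mask = rng.random((n_users, n_items)) < 0.6   # which entries count as observed
W = rng.normal(scale=0.1, size=(n_users, dim))
H = rng.normal(scale=0.1, size=(n_items, dim))

losses = [observed_loss(Y, mask, W, H)]
for _ in range(10):
    W = als_step(Y, mask, H, reg=0.1)         # fix H, solve for every user row
    H = als_step(Y.T, mask.T, W, reg=0.1)     # fix W, solve for every item row
    losses.append(observed_loss(Y, mask, W, H))
```

Each row solve is independent of the others, which is what makes the method easy to parallelize; the cost of one sweep grows with the number of rows, columns and observed entries.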

On the other hand, the recent success of deep learning, with its growing computational demands, has spurred a new wave of research and progress on hardware accelerators such as Tensor Processing Units (TPUs). TPUs afford domain-specific hardware speedups, especially for use cases like deep learning, which involves a large number of dense matrix multiplications. In particular, they allow significant speedups for traditional data-parallel workloads, such as training models with Stochastic Gradient Descent (SGD) in SPMD (single program, multiple data) fashion. The SPMD approach has gained popularity in computations like training neural networks with gradient descent algorithms, and can be used for both data-parallel and model-parallel computations, where we distribute the parameters of the model across available devices. Nevertheless, while TPUs have been enormously attractive for methods based on SGD, it is not immediately clear whether a high-performance implementation of ALS, which requires a large number of distributed sparse matrix multiplies, can be developed for a large-scale cluster of TPU devices.

In “ALX: Large Scale Matrix Factorization on TPUs”, we explore a distributed ALS design that makes efficient use of the TPU architecture and can scale well to matrix factorization problems on the order of billions of rows and columns by scaling the number of available TPU cores. The approach we propose leverages a combination of model and data parallelism, where each TPU core both stores a portion of the embedding table and trains over a unique slice of data, grouped in mini-batches. In order to spur future research on large-scale matrix factorization methods and to illustrate the scalability properties of our own implementation, we also built and released a real-world web link prediction dataset called WebGraph.

The figure shows the flow of data and computation through the ALX framework on TPU devices. Similar to SGD-based training procedures, each TPU core performs identical computation for its own batch of data in SPMD fashion, which allows for synchronous computation in parallel on multiple TPU cores. Each TPU starts by gathering all the relevant item embeddings in the Sharded Gather stage. These materialized embeddings are used to solve for user embeddings, which are then scattered to the relevant shard of the embedding table in the Sharded Scatter stage.

Dense Batching for Improved Efficiency
We designed ALX specifically for TPUs, exploiting unique properties of the TPU architecture while overcoming a few interesting limitations. For instance, each TPU core has limited memory and restricts all tensors to have a static shape, but each example in a mini-batch can have a wildly varying number of items (i.e., inputs can be long and sparse). To resolve this, we break exceedingly long examples into multiple smaller examples of the same shape, a process called dense batching. More details about dense batching can be found in our paper.

Illustrative example of how sparse batches are densified to increase efficiency on TPUs.
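The idea of dense batching can be sketched in a few lines of plain Python. This is only an illustration of the splitting-and-padding step described above, not the code from the paper; the function name `dense_batch`, the chunk `width` and the padding value `pad_id` are assumptions.

```python
def dense_batch(examples, width, pad_id=-1):
    """Split each (row_id, item_ids) example into fixed-width chunks, padding the last chunk."""
    rows, items = [], []
    for row_id, item_ids in examples:
        # max(..., 1) keeps one (fully padded) chunk for an empty example.
        for start in range(0, max(len(item_ids), 1), width):
            chunk = item_ids[start:start + width]
            chunk = chunk + [pad_id] * (width - len(chunk))
            rows.append(row_id)       # remember which original row the chunk came from
            items.append(chunk)
    return rows, items

# Three users with 2, 7 and 1 interactions respectively.
examples = [(0, [3, 7]), (1, [1, 2, 4, 5, 8, 9, 11]), (2, [6])]
rows, items = dense_batch(examples, width=3)
# rows  -> [0, 1, 1, 1, 2]
# items -> [[3, 7, -1], [1, 2, 4], [5, 8, 9], [11, -1, -1], [6, -1, -1]]
```

Every chunk now has the same static shape, and the repeated row id lets a downstream solver recombine the pieces that belong to the same long example.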

Uniform Sharding of Embedding Tables
With the batching problem solved, we next want to factorize a sparse matrix into two dense embedding matrices (e.g., user and item embeddings) such that the resulting dot product of embeddings approximates the original sparse matrix. This lets us infer predictions for all positions of the original matrix, including those that were empty, which can be used to recommend items with which users have not interacted. Both of the resulting embedding tables (W and H in the figure below) can potentially be too large to fit in a single TPU core, thus requiring a distributed training setup for most large-scale use cases.

Most previous attempts at distributed matrix factorization use a parameter server architecture, where the model parameters are stored on highly available servers and the training data is processed in parallel by workers that are solely responsible for the learning task. In our case, since each TPU core has identical compute and memory, it is wasteful to use a core only for its memory (to store model parameters) or only for its compute (to train). Thus, we designed our system such that every core is used to do both.

Illustrative example of factorizing a sparse matrix Y into two dense embedding matrices W and H.

In ALX, we uniformly divide both embedding tables, thus fully exploiting both the size of the distributed memory available and the dedicated low-latency interconnects between TPUs. This is highly efficient for very large embedding tables and results in good performance for distributed gather and scatter operations.

Uniform sharding of both embedding tables (W and H) across TPU cores (in blue).
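As a rough single-process illustration of uniform sharding, the sketch below splits an embedding table into equally sized contiguous shards and looks rows up by global id. On real TPUs the gather and scatter would be collective operations over the interconnect; here they are plain list operations, and all function names are hypothetical.

```python
def shard_table(table, num_shards):
    """Split rows into contiguous, equally sized shards, zero-padding the last one."""
    rows_per_shard = -(-len(table) // num_shards)            # ceiling division
    dim = len(table[0])
    padded = [list(row) for row in table]
    padded += [[0.0] * dim for _ in range(rows_per_shard * num_shards - len(table))]
    shards = [padded[s * rows_per_shard:(s + 1) * rows_per_shard]
              for s in range(num_shards)]
    return shards, rows_per_shard

def locate(row_id, rows_per_shard):
    """Map a global row id to (shard index, offset within the shard)."""
    return divmod(row_id, rows_per_shard)

def sharded_gather(shards, rows_per_shard, row_ids):
    """Collect copies of the requested rows from whichever shard holds them."""
    gathered = []
    for row_id in row_ids:
        shard, offset = locate(row_id, rows_per_shard)
        gathered.append(list(shards[shard][offset]))
    return gathered

def sharded_scatter(shards, rows_per_shard, row_ids, new_rows):
    """Write updated rows back to the shard that owns each of them."""
    for row_id, vec in zip(row_ids, new_rows):
        shard, offset = locate(row_id, rows_per_shard)
        shards[shard][offset] = list(vec)

table = [[float(i), float(i) + 0.5] for i in range(10)]     # 10 rows, dimension 2
shards, rows_per_shard = shard_table(table, num_shards=4)   # 3 rows per shard
gathered = sharded_gather(shards, rows_per_shard, [0, 3, 9])
sharded_scatter(shards, rows_per_shard, [9], [[-1.0, -1.0]])
```

Because every shard has the same size, each core holds the same amount of the table, and the owner of any row follows from simple integer arithmetic on its id.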

WebGraph
Since potential applications may involve very large datasets, scalability is potentially an important opportunity for advancement in matrix factorization. To that end, we also release a large real-world web link prediction dataset called WebGraph. This dataset can be easily modeled as a matrix factorization problem where rows and columns are source and destination links, respectively, and the task is to predict destination links from each source link. We use WebGraph to illustrate the scaling properties of ALX.

The WebGraph dataset was generated from a single crawl performed by CommonCrawl in 2021, from which we strip everything and keep only the link->outlinks data. Since the performance of a factorization method depends on the properties of the underlying graph, we created six versions of WebGraph, each varying in sparsity pattern and locale, to study how well ALS performs on each.

  • To study locale-specific graphs, we filter based on two top-level domains: ‘de’ and ‘in’, each producing a graph with an order of magnitude fewer nodes.
  • These graphs can still have arbitrary sparsity patterns and dangling links. Thus we further filter the nodes in each graph to have a minimum of either 10 or 50 inlinks and outlinks.

For easy access, we have made these available as a TensorFlow Datasets package. For reference, the biggest version, WebGraph-sparse, has more than 365M nodes and 30B edges. We create and publish both training and testing splits for evaluation purposes.

Results
We carefully tune the system and quality parameters of ALX, guided by our observations on numerical precision and the choice of linear solver. We observed that by carefully selecting the precision for storage of the embedding tables (bfloat16) and for the input to the linear solvers (float32), we were able to halve the memory required for the embeddings while still avoiding problems arising from lower-precision values during the solve stage. For our linear solvers, we selected conjugate gradients, which we found to be the fastest across the board on TPUs. We use embeddings of dimension 128 and train the model for 16 epochs. In our experience, hyperparameter tuning over both the norm penalty (λ) and the unobserved weight (α) has been indispensable for good recall metrics, as shown in the table below.
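The following NumPy sketch shows what a single conjugate-gradient solve for one user embedding could look like. Several things here are assumptions rather than details taken from ALX: the form of the linear system (observed items plus an α-weighted term over all items plus a λ penalty), the values of `lam` and `alpha`, and the use of float16 as a stand-in for bfloat16 storage, since NumPy has no native bfloat16 type.

```python
import numpy as np

def conjugate_gradient(A, b, iters=50, tol=1e-5):
    """Solve A x = b for a symmetric positive-definite matrix A."""
    x = np.zeros_like(b)
    r = b.copy()                       # residual b - A x, with x = 0
    p = r.copy()                       # first search direction
    rs = r @ r
    threshold = (tol ** 2) * (b @ b)   # stop at a relative residual of tol
    for _ in range(iters):
        if rs <= threshold:
            break
        Ap = A @ p
        step = rs / (p @ Ap)
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p      # next direction, conjugate to the previous ones
        rs = rs_new
    return x

rng = np.random.default_rng(0)
n_items, dim = 20, 8                   # hypothetical toy sizes
lam, alpha = 1.0, 0.1                  # hypothetical norm penalty and unobserved weight

# Store the item table in 16 bits (float16 stands in for bfloat16 here) ...
H_stored = rng.normal(size=(n_items, dim)).astype(np.float16)
# ... and cast to float32 before it enters the solver.
H = H_stored.astype(np.float32)

observed = np.array([1, 4, 7, 12, 19])             # items this user interacted with
y = np.ones(len(observed), dtype=np.float32)       # their interaction values

H_obs = H[observed]
A = H_obs.T @ H_obs + alpha * (H.T @ H) + lam * np.eye(dim, dtype=np.float32)
b = H_obs.T @ y
w = conjugate_gradient(A, b)                       # the user's embedding
```

The 16-bit table takes half the bytes of its float32 copy, while the solver only ever sees float32 inputs. Conjugate gradients needs nothing but matrix-vector products, which map well onto dense accelerator hardware.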

Results obtained by running ALX on all versions of the WebGraph dataset. Recall values of 1.0 denote perfect recall.

Scaling Analysis
Since the input data are processed in parallel across TPU cores, increasing the number of cores decreases training time, ideally in a linear fashion. But at the same time, a larger number of cores requires more network communication (due to the sharded embedding tables). Thanks to high-speed interconnects, this overhead can be negligible for a small number of cores, but as the number of cores increases, the overhead eventually slows down the ideal linear scaling.
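The trade-off can be made tangible with a deliberately simple toy model. It is not fitted to any measurement: the compute time, the communication cost and the assumption that overhead grows with the logarithm of the core count are all hypothetical.

```python
import math

def epoch_time(cores, compute_seconds=10_000.0, comm_seconds=5.0):
    """Toy model: perfectly divided compute plus overhead that grows with log2(cores)."""
    return compute_seconds / cores + comm_seconds * math.log2(cores)

core_counts = [1, 8, 32, 128, 512]
times = [epoch_time(c) for c in core_counts]

# Parallel efficiency: 1.0 means ideal linear scaling relative to a single core.
efficiency = [times[0] / (c * t) for c, t in zip(core_counts, times)]
for c, t, e in zip(core_counts, times, efficiency):
    print(f"{c:4d} cores: {t:9.1f} s per epoch, efficiency {e:.2f}")
```

In this model the efficiency stays close to 1.0 while the compute term dominates and falls off once the communication term becomes comparable, which is the qualitative shape the hypothesis predicts.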

To confirm our hypothesis, we analyze the scaling properties of the four biggest WebGraph variants in terms of training time as we increase the number of available TPU cores. As shown below, we do observe empirically the predicted linear decrease in training time up to a sweet spot, after which the network overhead slows the decline.

Scaling analysis of running time as the number of TPU cores is increased. Each figure plots the time taken to train for one epoch, in seconds.

Conclusion
For easy access and reproducibility, the ALX code is open-sourced and can be easily run on Google Cloud. In fact, we illustrate that a sparse matrix like WebGraph-dense of size 135M x 135M (with 22B edges) can be factorized in a Colab connected to 8 TPU cores in less than a day. We have designed the ALX framework with scalability in mind. With 256 TPU cores, one epoch of the largest WebGraph variant, WebGraph-sparse (a 365M x 365M sparse matrix), takes around 20 minutes to finish (5.5 hours for the whole training run). The final model has around 100B parameters. We hope that ALX and WebGraph will be useful to both researchers and practitioners working in these fields. The code for ALX can be found on GitHub.
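The headline numbers can be cross-checked with some back-of-the-envelope arithmetic. The inputs below are the rounded figures quoted in the text (365M rows and columns, dimension 128, 16 epochs, about 20 minutes per epoch, 256 cores); the even split of memory across cores is an assumption that follows from uniform sharding.

```python
nodes = 365_000_000          # rows (and columns) of WebGraph-sparse, rounded
dim = 128                    # embedding dimension
cores = 256

params = 2 * nodes * dim     # both tables, W and H
bytes_bf16 = params * 2      # 2 bytes per bfloat16 value
bytes_f32 = params * 4       # 4 bytes per float32 value

epochs = 16
minutes_per_epoch = 20       # "around 20 minutes" in the text
total_hours = epochs * minutes_per_epoch / 60

print(f"parameters:        {params / 1e9:.2f} billion")            # 93.44 billion
print(f"bfloat16 storage:  {bytes_bf16 / 1e9:.2f} GB in total")    # 186.88 GB
print(f"per core:          {bytes_bf16 / cores / 1e9:.2f} GB")     # 0.73 GB
print(f"float32 storage:   {bytes_f32 / 1e9:.2f} GB in total")     # 373.76 GB
print(f"training time:     {total_hours:.1f} hours")               # 5.3 hours
```

The roughly 93 billion parameters match the "around 100B" figure, and 16 epochs at about 20 minutes each land close to the reported 5.5 hours.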

Acknowledgements
The core team includes Steffen Rendle, Walid Krichene and Li Zhang. We thank many Google colleagues for helping at various stages of this project. In particular, we are grateful to the JAX team for numerous discussions, especially James Bradbury and Skye Wanderman-Milne; Blake Hechtman for help with XLA; and Rasmus Larsen for useful discussions about the performance of linear solvers on TPUs. Finally, we are also grateful to Nicolas Mayoraz, John Anderson, and Fernando Pereira for providing useful feedback.


