Learning With Matrices
Recently I was talking to a new software engineer who is trying to learn Rust. They felt like they had a good grasp of the basics and wanted to write a small application to test their skills. My suggestion was that they code up a matrix struct (and later, a trait) that encapsulates the behaviour expected of a general matrix. This seemed like quite a fun task, so I had a go too.
My attempt
It didn’t take me long to put together a very basic matrix struct, and a handful of matrix functions. It looked something like this:


There is nothing complicated here.
As expected, we have an integer to hold the number of rows, a second integer for the number of columns, and a blob of data (in this case a Vec<S>) to hold the elements of the matrix.
So far so good, right?
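For concreteness, here is a minimal sketch of a struct matching that description. The field names n_rows and n_cols come from later in this post. The field name data and the row-major storage convention (element (i, j) at index i * n_cols + j) are my assumptions.

```rust
/// A dense matrix whose elements have scalar type `S`.
/// Assumption: row-major storage, so element (i, j) lives at `data[i * n_cols + j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<S> {
    n_rows: usize,
    n_cols: usize,
    data: Vec<S>,
}

impl<S> Matrix<S> {
    /// Number of rows.
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Number of columns.
    pub fn n_cols(&self) -> usize {
        self.n_cols
    }

    /// Borrow the element in row `i`, column `j` using a strided index.
    pub fn get(&self, i: usize, j: usize) -> &S {
        assert!(i < self.n_rows && j < self.n_cols, "index out of bounds");
        &self.data[i * self.n_cols + j]
    }
}
```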
With everything seemingly fine, I went on to put together some basic functions, starting with a new constructor function:
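A minimal sketch of such a constructor, assuming fields named n_rows, n_cols and data, with the data supplied in row-major order. The assertion message is illustrative.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<S> {
    n_rows: usize,
    n_cols: usize,
    data: Vec<S>, // assumed row-major
}

impl<S> Matrix<S> {
    /// Build a matrix from its dimensions and row-major element data.
    /// Panics if the amount of data does not match the requested shape.
    pub fn new(n_rows: usize, n_cols: usize, data: Vec<S>) -> Self {
        assert_eq!(
            data.len(),
            n_rows * n_cols,
            "expected {} elements for a {}x{} matrix",
            n_rows * n_cols,
            n_rows,
            n_cols
        );
        Matrix { n_rows, n_cols, data }
    }
}
```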


Nothing complicated here, just taking the data that one would use to create a matrix and returning the necessary struct, including a little assertion that all the sizes match up correctly. I didn’t originally implement these, but here are some alternative construction functions that one might have defined:
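Here is a sketch of two such alternatives, zeros and from_fn (hypothetical names chosen for illustration). zeros is where a Default bound on S comes in handy; Clone is also needed to fill the vector with copies of the default value.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<S> {
    n_rows: usize,
    n_cols: usize,
    data: Vec<S>, // assumed row-major
}

impl<S: Default + Clone> Matrix<S> {
    /// Matrix with every element set to `S::default()` (zero for the numeric types).
    pub fn zeros(n_rows: usize, n_cols: usize) -> Self {
        Matrix { n_rows, n_cols, data: vec![S::default(); n_rows * n_cols] }
    }
}

impl<S> Matrix<S> {
    /// Matrix whose (i, j) element is `f(i, j)`, filled in row-major order.
    pub fn from_fn<F: FnMut(usize, usize) -> S>(n_rows: usize, n_cols: usize, mut f: F) -> Self {
        let mut data = Vec::with_capacity(n_rows * n_cols);
        for i in 0..n_rows {
            for j in 0..n_cols {
                data.push(f(i, j));
            }
        }
        Matrix { n_rows, n_cols, data }
    }
}
```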


Ok, enough playing around with the easy stuff, let’s jump ahead to matrix multiplication.
(Any time we talk about implementing a matrix, what we’re really talking about is matrix multiplication.)
Now matrix multiplication places some requirements on the scalar type, which I’ve denoted S.
Namely, we must be able to multiply two scalars together to get another scalar (that is, *), and we must be able to add one scalar to another in place (that is, +=).
These properties are encapsulated in the standard library by the traits Mul and AddAssign.
Once we constrain the scalar type S with these two traits, and the Default trait we used above for simplicity, we can implement matrix multiplication like this:
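A minimal sketch of what this might look like, assuming row-major storage, a method named matmul (hypothetical), and an extra S: Copy bound (my addition) so that elements can be multiplied by value without explicit clones.

```rust
use std::ops::{AddAssign, Mul};

#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<S> {
    n_rows: usize,
    n_cols: usize,
    data: Vec<S>, // assumed row-major
}

impl<S> Matrix<S>
where
    S: Default + Copy + Mul<Output = S> + AddAssign,
{
    pub fn matmul(&self, other: &Matrix<S>) -> Matrix<S> {
        // The product is only defined when the inner dimensions agree.
        assert_eq!(self.n_cols, other.n_rows, "incompatible shapes");
        // Zero-initialised result with shape (self.n_rows, other.n_cols).
        let mut result = Matrix {
            n_rows: self.n_rows,
            n_cols: other.n_cols,
            data: vec![S::default(); self.n_rows * other.n_cols],
        };
        for i in 0..self.n_rows {
            // Row i of the left matrix is a contiguous slice.
            let lhs_row = &self.data[i * self.n_cols..(i + 1) * self.n_cols];
            for j in 0..other.n_cols {
                let out = &mut result.data[i * other.n_cols + j];
                // Dot product of row i (left) with column j (right).
                for (k, &l) in lhs_row.iter().enumerate() {
                    // Column j of the right matrix is strided by other.n_cols.
                    *out += l * other.data[k * other.n_cols + j];
                }
            }
        }
        result
    }
}
```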


This looks like a mess of code, so a little explanation is in order.
First we check that the left hand matrix (&self) has the same number of columns as the right hand matrix (&other) has rows.
This is the condition required for matrix multiplication to be valid.
Next we create a new matrix result that will hold the product.
This matrix has the same number of rows as the left hand matrix, and the same number of columns as the right hand matrix.
Now we have to iterate through each of the rows and columns, and fill the corresponding value in the output.
In the ijth position, the value is the “dot product” of the ith row of the left and jth column of the right.
This “dot product” is the innermost loop of the computation.
Notice that we use the Default trait when we create the zero matrix, and the AddAssign and Mul traits via the += and * operators.
To save a bit of space, I’ve also used slicing to extract a relevant part of the left and right hand matrices, and simplify the index calculations in the process.
Here we can clearly see the principal reason that matrix multiplication can be quite slow. In the innermost loop, consecutive accesses to the right hand matrix are spaced far apart in memory (a whole row apart, since we are walking down a column). This makes it tricky to read the data into the computation efficiently. Optimised libraries like the Intel Math Kernel Library (MKL) use a variety of techniques to make the most of the data, and the effects are dramatic. More on this later, maybe.
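One of the simplest tricks in this direction, far simpler than what an optimised library does, is to reorder the loops (i, k, j instead of i, j, k) so that the innermost loop walks along a row of the right hand matrix and a row of the output, both of which are contiguous in row-major storage. A sketch over plain slices, with the hypothetical name matmul_ikj:

```rust
use std::ops::{AddAssign, Mul};

/// Multiply row-major `a` (m x n) by row-major `b` (n x p), returning a row-major m x p result.
/// The i-k-j loop order keeps the innermost loop on contiguous memory in both `b` and `c`.
pub fn matmul_ikj<S>(a: &[S], b: &[S], m: usize, n: usize, p: usize) -> Vec<S>
where
    S: Default + Copy + Mul<Output = S> + AddAssign,
{
    assert_eq!(a.len(), m * n, "a has the wrong length");
    assert_eq!(b.len(), n * p, "b has the wrong length");
    let mut c = vec![S::default(); m * p];
    for i in 0..m {
        let c_row = &mut c[i * p..(i + 1) * p];
        for k in 0..n {
            let a_ik = a[i * n + k];
            // Row k of b is contiguous, so this loop streams through memory.
            let b_row = &b[k * p..(k + 1) * p];
            for (c_ij, &b_kj) in c_row.iter_mut().zip(b_row) {
                *c_ij += a_ik * b_kj;
            }
        }
    }
    c
}
```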
Back to the main story
What I’ve just described is how I set about defining a Matrix struct in Rust, and, of course, the other software engineer set things out a little differently. Actually, the difference between the two approaches is very interesting, and exposes various issues around software design, the Rust language, and computational efficiency. Their struct looked like this:
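A minimal sketch of a struct along these lines. The nested-array representation [[f64; N]; M] matches the two-sets-of-brackets indexing described below; the parameter names M and N, the field name data and the helper methods are my assumptions.

```rust
/// A matrix of `f64` whose shape `M x N` is part of its type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}

impl<const M: usize, const N: usize> Matrix<M, N> {
    pub fn new(data: [[f64; N]; M]) -> Self {
        Matrix { data }
    }

    /// Indexing is just two sets of brackets: row `i`, then column `j`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        self.data[i][j]
    }
}
```

The shape is inferred from the array literal, so Matrix::new([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) has type Matrix<2, 3>.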


The first thing to note is that we use generics in different ways.
In my approach, I made the struct generic over the scalar type S.
In this approach, the struct is generic over the shape of the matrix, and the scalar type is fixed as double (f64).
This approach benefits from the fact that indexing into the matrix data is substantially easier (in some ways) since one simply uses two sets of brackets rather than using a strided access as in my approach.
Moreover, since the matrix dimensions are baked into the struct itself, we don’t need to do runtime checks to see if the matrices are compatible.
Such checks can be done by the Rust type system.
To illustrate the point, here’s how you might implement matrix multiplication for this struct.
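A sketch, assuming the struct holds a nested array [[f64; N]; M] in a field called data (names are assumptions). Implementing the standard Mul trait means the shared inner dimension N is checked by the type checker rather than at run time.

```rust
use std::ops::Mul;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}

/// (M x N) * (N x P) = (M x P); the shared `N` is enforced at compile time.
impl<const M: usize, const N: usize, const P: usize> Mul<Matrix<N, P>> for Matrix<M, N> {
    type Output = Matrix<M, P>;

    fn mul(self, rhs: Matrix<N, P>) -> Matrix<M, P> {
        let mut data = [[0.0; P]; M];
        for i in 0..M {
            for j in 0..P {
                // Dot product of row i of self with column j of rhs.
                for k in 0..N {
                    data[i][j] += self.data[i][k] * rhs.data[k][j];
                }
            }
        }
        Matrix { data }
    }
}
```

With a: Matrix<2, 3>, writing a * a is rejected at compile time, because there is no Mul<Matrix<2, 3>> implementation for Matrix<2, 3>.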


This implementation is very clean, and there is no runtime cost of checking matrix shapes. Moreover, since the dimensions are known at compile time, the compiler can unleash the whole armoury of optimisations on these loops to make them run as fast as possible.
It’s actually quite interesting to note that const generics (non-type, compile-time constant integer generic parameters) are a relatively new addition to the Rust toolkit; they were stabilised in Rust 1.51, in 2021. I started learning Rust in 2018, long before they were added to the language. However, there is another reason I didn’t go with this approach. Type system based constraints certainly simplify runtime code, and make it harder to make mistakes, but they suffer from one big problem: you (the programmer) have to compute the type of the answer every time, and correctly record it in function and trait definitions. For a simple function like this one, computing the result type is easy. However, it is easy to come up with situations where it is not. Consider the following:


This looks fine, and in C++ the equivalent template code would compile without complaint. Unfortunately, Rust’s const generics are not quite as mature as C++ templates (though they are a lot safer), so we actually get the following error:


(For future reference, this is using Rust 1.65.) In all fairness, this is not a standard matrix operation but it does illustrate the problem rather nicely.
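One workaround on stable Rust, included here as my own illustration rather than a fix for the example above, is to have the programmer supply the result dimension as an extra const parameter and check it. The function hstack (horizontal concatenation) and its parameter names are hypothetical, and the N1 + N2 == N3 check happens when the function runs, not at compile time.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const M: usize, const N: usize> {
    data: [[f64; N]; M],
}

/// Hypothetical horizontal concatenation [a | b].
/// Stable Rust cannot write the output width as `{ N1 + N2 }`, so the caller
/// supplies it as a separate parameter `N3`, and we check it at run time.
pub fn hstack<const M: usize, const N1: usize, const N2: usize, const N3: usize>(
    a: &Matrix<M, N1>,
    b: &Matrix<M, N2>,
) -> Matrix<M, N3> {
    assert_eq!(N1 + N2, N3, "output width must equal N1 + N2");
    let mut data = [[0.0; N3]; M];
    for i in 0..M {
        // Left block comes from a, right block from b.
        data[i][..N1].copy_from_slice(&a.data[i]);
        data[i][N1..].copy_from_slice(&b.data[i]);
    }
    Matrix { data }
}
```

The caller still has to write out the result type (Matrix<2, 3> for a 2x1 block next to a 2x2 block), which is exactly the burden described above.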
Generalising to multidimensional arrays
A limitation of both our approaches is that neither can be generalised to multidimensional arrays. (Such multidimensional arrays are sometimes called tensors, but I find this terminology rather unsatisfying.) Both our approaches hard code the number of dimensions as 2. Common implementations of multidimensional arrays, like Python’s NumPy, make use of a list of sizes, one for each dimension of the array. This list is usually called the shape of the array. Here’s what a multidimensional array might look like:


The struct is very similar in design to my implementation of the matrix.
We simply replace the n_rows and n_cols fields with a vector of usize describing the shape of the array.
In this case, we can’t use const generics to describe the shape unless the number of dimensions is fixed in advance.
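A minimal sketch, assuming row-major (C order) storage and the hypothetical field names shape and data. The flat offset of a multi-index is accumulated one dimension at a time, much as digits are combined in positional notation.

```rust
/// An n-dimensional array: `shape[d]` is the extent of dimension `d`.
/// Assumption: data is stored in row-major (C) order.
#[derive(Debug, Clone, PartialEq)]
pub struct NDArray<S> {
    shape: Vec<usize>,
    data: Vec<S>,
}

impl<S> NDArray<S> {
    pub fn new(shape: Vec<usize>, data: Vec<S>) -> Self {
        assert_eq!(shape.iter().product::<usize>(), data.len(), "shape does not match data length");
        NDArray { shape, data }
    }

    /// Number of dimensions (2 for a matrix).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Row-major flat offset: ((i0 * n1 + i1) * n2 + i2) and so on.
    pub fn offset(&self, index: &[usize]) -> usize {
        assert_eq!(index.len(), self.ndim(), "wrong number of indices");
        index.iter().zip(&self.shape).fold(0, |acc, (&i, &n)| {
            assert!(i < n, "index out of bounds");
            acc * n + i
        })
    }

    pub fn get(&self, index: &[usize]) -> &S {
        &self.data[self.offset(index)]
    }
}
```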
Something that is missing from both matrix implementations and from the array definition above is an explicit mention of the “layout” of the data. For matrices, there are two clear orders in which the data can be stored: row major, where the data from each row is stored contiguously; and column major, where the data from each column is stored contiguously. Row major ordering is the standard in C whilst column major is the standard in Fortran, so these are sometimes referred to as “C order” and “F order”, respectively. For multidimensional arrays, the same choice is available: either the last index or the first index varies fastest in memory. We can add this now in the form of an enum:
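A sketch of such an enum. The variant names RowMajor and ColumnMajor and the helper functions strides and offset are my own choices. The layout determines each dimension’s stride (the distance in memory between consecutive indices along that dimension), and the flat offset of a multi-index is the dot product of the index with the strides.

```rust
/// Memory layout of an n-dimensional array.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Layout {
    /// Row major ("C order"): the last index varies fastest.
    RowMajor,
    /// Column major ("F order"): the first index varies fastest.
    ColumnMajor,
}

/// Strides (in elements) for each dimension of an array with the given shape and layout.
pub fn strides(shape: &[usize], layout: Layout) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut step = 1;
    match layout {
        Layout::RowMajor => {
            // Walk from the last dimension to the first.
            for d in (0..shape.len()).rev() {
                strides[d] = step;
                step *= shape[d];
            }
        }
        Layout::ColumnMajor => {
            // Walk from the first dimension to the last.
            for d in 0..shape.len() {
                strides[d] = step;
                step *= shape[d];
            }
        }
    }
    strides
}

/// Flat offset of a multi-index: the dot product of the index with the strides.
pub fn offset(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}
```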


The ordering of the data becomes important when performing various algorithms for linear algebra, such as solving a system of linear equations.
To round things out, let’s have a look at something that causes a few problems for an ergonomic multidimensional array type in Rust: the arithmetic traits.
To enable the use of arithmetic operators such as +, one must implement the corresponding trait (here Add) for the types involved.
For simple types this works just fine, but for more complex types this becomes rather problematic.
For instance, one might implement addition for multidimensional arrays as follows:
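A minimal sketch of element-wise addition that consumes both operands, assuming the NDArray struct has fields named shape and data (names assumed) and that shapes must match exactly (no broadcasting):

```rust
use std::ops::Add;

#[derive(Debug, Clone, PartialEq)]
pub struct NDArray<S> {
    shape: Vec<usize>,
    data: Vec<S>,
}

/// Element-wise addition that takes both operands by value.
impl<S: Add<Output = S>> Add for NDArray<S> {
    type Output = NDArray<S>;

    fn add(self, other: NDArray<S>) -> NDArray<S> {
        assert_eq!(self.shape, other.shape, "shape mismatch");
        // Both Vecs are moved in, so their elements can be consumed without cloning.
        let data = self
            .data
            .into_iter()
            .zip(other.data)
            .map(|(a, b)| a + b)
            .collect();
        NDArray { shape: self.shape, data }
    }
}
```

After let c = a + b;, both a and b have been moved and can no longer be used.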


Looks reasonable, but can you see the problem?
The problem lies in the signature of the add function: it takes self and other by (moved) value.
If the types were trivial (such as f64), this would not be a problem at all, since copies are extremely cheap (and are essentially unavoidable as these values are moved into CPU registers, if you’re thinking on this level).
However, our types are not trivial; they own (potentially large amounts of) data.
Now Rust has move semantics by default, so actually this doesn’t “cost” anything in and of itself, since no data will actually be copied.
However, both the left hand and right hand NDArray objects will be consumed by this operation, and we won’t be able to reuse them later.
Ideally, we’d make use of shared references to these large objects, since this way we can avoid copying or moving altogether.
The code looks pretty similar, with the addition of & in the relevant places and a couple of lifetime parameters.
(I’ll leave this to the reader’s imagination to save a little space.)
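For readers who would rather not imagine it, here is a sketch. It assumes S: Copy (my simplification, so elements can be added by value) and hypothetical field names shape and data.

```rust
use std::ops::Add;

#[derive(Debug, Clone, PartialEq)]
pub struct NDArray<S> {
    shape: Vec<usize>,
    data: Vec<S>,
}

/// Element-wise addition of two borrowed arrays; neither operand is consumed.
impl<'a, 'b, S> Add<&'b NDArray<S>> for &'a NDArray<S>
where
    S: Copy + Add<Output = S>,
{
    type Output = NDArray<S>;

    fn add(self, other: &'b NDArray<S>) -> NDArray<S> {
        assert_eq!(self.shape, other.shape, "shape mismatch");
        let data = self.data.iter().zip(&other.data).map(|(&a, &b)| a + b).collect();
        // The result needs its own copy of the shape.
        NDArray { shape: self.shape.clone(), data }
    }
}
```

Neither a nor b is consumed, but every call site now has to be written as &a + &b.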
Unfortunately, we now have a different problem:


Notice that every time we use the + operator with NDArrays we have to insert & to indicate that we’re applying the arithmetic operation to the shared references.
To me, this is one of the greatest problems with defining such structures in Rust at the moment.
In C++, an object (or a non-const reference) binds silently to a const reference parameter, so this is simply part of the operator lookup routine.
However, in Rust this needs to be explicit (to some extent).
I believe there is a way to address this, at least superficially, using the Borrow trait and generics, but this seems like huge overkill for such a simple problem.
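For the curious, here is a sketch of what that might look like, again assuming S: Copy and hypothetical field names shape and data. Each impl accepts any right hand operand R: Borrow<NDArray<S>>, which covers both NDArray<S> and &NDArray<S>, since the standard library implements Borrow<T> for T and for &T. The left operand still needs an explicit & if we want to keep it, which is why this only helps superficially.

```rust
use std::borrow::Borrow;
use std::ops::Add;

#[derive(Debug, Clone, PartialEq)]
pub struct NDArray<S> {
    shape: Vec<usize>,
    data: Vec<S>,
}

/// Shared element-wise kernel used by both impls below.
fn add_arrays<S: Copy + Add<Output = S>>(lhs: &NDArray<S>, rhs: &NDArray<S>) -> NDArray<S> {
    assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
    let data = lhs.data.iter().zip(&rhs.data).map(|(&a, &b)| a + b).collect();
    NDArray { shape: lhs.shape.clone(), data }
}

/// Borrowed left operand; the right operand may be owned or borrowed.
impl<'a, S, R> Add<R> for &'a NDArray<S>
where
    S: Copy + Add<Output = S>,
    R: Borrow<NDArray<S>>,
{
    type Output = NDArray<S>;
    fn add(self, rhs: R) -> NDArray<S> {
        let rhs: &NDArray<S> = rhs.borrow();
        add_arrays(self, rhs)
    }
}

/// Owned left operand; the right operand may be owned or borrowed.
impl<S, R> Add<R> for NDArray<S>
where
    S: Copy + Add<Output = S>,
    R: Borrow<NDArray<S>>,
{
    type Output = NDArray<S>;
    fn add(self, rhs: R) -> NDArray<S> {
        let rhs: &NDArray<S> = rhs.borrow();
        add_arrays(&self, rhs)
    }
}
```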
Anyway, we shall leave this experiment here for now.