WIP notes on machine learning, reader beware, mostly for self study but maybe they are useful to you too


Often we want to find some function f which fits some property, like if we have a dataset of inputs and outputs xi and yi, we’d like to find an f such that f(xi) = yi. Finding such a function isn’t normally easy! Almost always, we have to settle for an approximation such that f(xi) “is close to” yi. There are many ways we can choose to measure what we mean by “is close to” and many ways we can measure how well we do on the entire task. These make up the loss function: loss(f(xi), yi) -> scalar and we commonly seek to minimize the mean loss across our dataset. Mean isn’t the only possible choice, you might want to minimize the worst case loss, or weight certain inputs as more important to do better on etc.
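
A tiny numpy sketch of these choices (the helper name mse is mine, nothing standard): a per-example loss, then mean vs worst-case vs weighted reductions over the dataset.

    import numpy as np

    def mse(pred, target):
        # per-example squared error: one scalar per (prediction, target) pair
        return (pred - target) ** 2

    preds = np.array([1.0, 2.5, 0.0])
    targets = np.array([1.0, 2.0, 1.0])
    per_example = mse(preds, targets)                        # [0.0, 0.25, 1.0]
    mean_loss = per_example.mean()                           # the usual choice to minimize
    worst_case = per_example.max()                           # or minimize the worst example
    weighted = np.average(per_example, weights=[1, 1, 5])    # or weight some inputs as more important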

What are the inputs and outputs of f? f : scalar -> scalar just isn’t that interesting! So we quickly move to multiple dimensions and think in terms of linear algebra. We might have f : vector -> scalar or f : vector -> vector or f : matrix -> scalar etc. Ex. image classification of “is dog in image?” could be f : (1024,1024,3) -> scalar where the input is a 1024x1024 pixel image with 3 color channels at each pixel RedGreenBlue. (1024,1024,3) is called its shape. We can equivalently think of the image as (3,1024,1024), ie 3 1024x1024 pixel images, one for each color. One natural choice of scalar output for the dog classifier would be positive values for dog and negative values for not dog, where the distance from 0 represents the confidence: 100 is definitely dog, -0.1 is an unsure not dog.

Some data like language text needs to be converted into numbers so that we can compute on it. For basic English text, we could assign the integers 1-26 to lowercase letters, 27 to a space, etc. There is a ubiquitous version of this already called ascii used in computers (the next most common is likely unicode in utf-8, but there are tons of encodings). If we wanted to predict the next letter of some text of length T, our f : (T) -> scalar would take in a vector of the integers corresponding to the text we want to predict and output a scalar that is its prediction. In general, we work with real numbers (well actually floating point) instead of integers, so how should we interpret an output of 22.35? Do we round it, etc. Instead, we might choose f : (T) -> (V) where V is our vocabulary size and each entry in (V) would mean how much the model thinks the next letter will be the ith letter. A perfect prediction might be represented by a one-hot vector where the correct ith position would be 1 and 0 everywhere else. You can choose whatever output format you want, but one nice thing for a task like this of discrete choices is having the model output a probability distribution or probability vector where the probability in the ith position is the probability of the ith letter being the next letter. The sum of the vector should add up to 1 since it is a probability vector. One trick is that we can turn any vector (take the output of f) into a vector with sum 1 by dividing by its sum, but there might also be negative entries so we also need to somehow deal with those. To do that, we can take e^f(x) since negative values get mapped to small positive values and positive values to large positive values. So overall, we can do normalize(exp(f(x))) which turns whatever output we get from f into a probability vector, nice! normalize(exp(…)) is called softmax because it accentuates/amplifies larger values and diminishes/filters smaller and negative values. People call the raw output f(x) logits but I’m not entirely sure that is true to the meaning of logit (log odds or log(p/(1-p))).
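
A hand-rolled sketch of normalize(exp(…)); subtracting the max first is a standard numerical-stability trick and doesn’t change the result.

    import numpy as np

    def softmax(logits):
        # exp maps negatives to small positive values, then dividing by the sum
        # gives a probability vector (nonnegative entries that add up to 1)
        z = np.exp(logits - logits.max())   # subtracting the max avoids overflow, same result
        return z / z.sum()

    raw = np.array([2.0, -1.0, 0.5])        # whatever f(x) handed us
    p = softmax(raw)                        # roughly [0.79, 0.04, 0.18], sums to 1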

Extending the idea that integers suck for modeling, instead of one integer per vocab character, let’s assign one real valued vector of size C (channels) to each character. This gives us a matrix of shape (V,C) and now if we are doing next character prediction we want f : (T,C) -> (V). To use it on an input, we take our T characters, look up each character in our dictionary (commonly called the embedding matrix) and form a matrix (T,C). We can call that a function embed : (T) -> (T,C). Then, we use the softmax trick to turn the output of f into a probability vector over our vocabulary. The whole thing looks like softmax(f(embed(xi))). But wait, our known output yi is a single character/integer still, so what does it mean for an integer to be “close to” a probability vector; in other words what is our loss? First, let’s think of our yi as a one hot vector again, which is also a probability vector. Now we can ask what loss functions we can use to compare probability vectors. Commonly this is cross entropy which is -sum(mul(p, log(q))); it is the negative p-weighted sum of log(q). Here p is the yi one hot pvector and q is our output pvector. Since yi is one-hot, the only nonzero term is at yi (the character), so the loss is -log(q[yi]). When computing this, pytorch for example wants the pre-softmax output because we don’t actually care about the values at any index besides yi and so we can skip the division everywhere except at yi. So then it is like x’ = exp(f(x)), loss_i = -log(x’[yi] / sum(x’)).
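
A sketch of that loss computed straight from the pre-softmax output (my own helper, written in the log-sum-exp form that libraries use so nothing overflows):

    import numpy as np

    def cross_entropy_from_logits(logits, yi):
        # loss = -log(softmax(logits)[yi]) = -logits[yi] + log(sum(exp(logits)))
        shifted = logits - logits.max()     # shift for numerical stability only
        return -shifted[yi] + np.log(np.exp(shifted).sum())

    logits = np.array([1.0, 4.0, -2.0])     # f(embed(xi)): one entry per vocab item
    yi = 1                                  # the observed next character
    loss = cross_entropy_from_logits(logits, yi)   # small here, since index 1 already dominates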

Okay, so we have inputs, outputs, and a way to measure how good the function is at doing what we want; how do we actually represent a good f? Let’s consider a simple f for our next character prediction task (after embedding) so f : (T,C) -> (V). Linear algebra tells us that matrices represent/are linear functions of shape (O,I) where O is output dim and I is input dim. We can apply it to an input vector x of size (I) with a left multiply: Ax has shape (O,I)(I,1) -> (O,1). We can freely insert dimensions of shape 1 to make things work and (I,1) is I rows of 1 column so we call it a column vector. Generally matrix multiply has shape (M,N)(N,O) -> (M,O) and can be thought of as running the function from the left matrix on each column of the right matrix. What does the left matrix do? Remember that each entry output[i, j] is the dot product of the ith row and jth column (for AB = C that is C[i, j] = dot(row(A, i), col(B, j))). The dot product is a component-wise multiply and accumulate with sum: dot(x:vector, y:vector) -> scalar = sum(mul(xi, yi)) = l(x)l(y)cos(angle(x, y)). The second interpretation there is that the dot is also the product of the vector lengths (sqrt(sum(square(x))) = n-d hypotenuse) with the cosine of the angle between the vectors. When we have unit length vectors ie l(x) = 1 and l(y) = 1, then dot(x, y) = cos(angle(x, y)) so we get a scalar in [-1, 1] based on the angle between the vectors. We can interpret vectors as being similar/close if their angle is small, since a small angle makes cos close to 1. Orthogonal/perpendicular vectors give 0 (cos(pi/2) = 0). Vectors pointing in opposite directions will have angle 180 (cos(pi) = -1). So the dot product gives a value in [-1, 1] and tells us a measure of how much the vectors are pointing in similar directions. Back to our matrix multiply Ax (O,I)(I,1) -> (O,1): A has O rows of I-vectors and each row i of the output is dot(row(A, i), x). So the matrix multiply measures the similarity of each of its rows with the input vector x. Our output vector (O,1) at each ith coordinate tells us how similar the input vector was to the ith row of A. Thinking of our f as a model of our input, we can imagine A as being a list of features we’re interested in and Ax will check how much the input matches each of those features. In the general case AB = C (M,N)(N,O) -> (M,O), we are computing the similarity of each column of B against each feature in A. This is why we get shape (M,O) out, because there are M rows of A and O columns of B. For next character prediction f : (T,C) -> (V) we could conjure up two matrices P and Q and do (V,T)(T,C)(C,1) -> (V,1). Oh yeah, and matrix multiply is associative (but not commutative).
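
A quick numpy check of the “Ax = dot every row of A with x” reading, plus the unit-length/cosine version:

    import numpy as np

    O, I = 4, 3
    A = np.random.randn(O, I)     # O feature rows, each an I-vector
    x = np.random.randn(I)        # one input vector

    via_matmul = A @ x                                     # shape (O,)
    via_dots = np.array([A[i] @ x for i in range(O)])      # same thing, one dot per row
    assert np.allclose(via_matmul, via_dots)

    # cosine view: with unit-length rows and a unit-length x, every entry lands in [-1, 1]
    A_unit = A / np.linalg.norm(A, axis=1, keepdims=True)
    x_unit = x / np.linalg.norm(x)
    similarities = A_unit @ x_unit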

Okay, so we have a representation of f as P (V,T), Q (C,1), and E (V,C) (embedding dictionary), what numbers do we put in these matrices? Random! Yes we’ll be lazy and just randomly initialize them. We can now run our function f on some input xi and measure the loss against yi. It will almost surely be a big loss, this f is bad, this is okay. Because now we can try to find a slightly better f by computing the gradient (grad : scalar -> vector) of the loss function at our input xi: grad(loss(f(xi), yi)). The gradient returns a vector of derivatives, one for each parameter in f. The parameters (sometimes called theta) in f are the entries in its matrices. These derivatives are the partial derivatives of the loss function with respect to that parameter at the value xi. The derivative tells us the slope of the linear approximation of our loss function (wrt each parameter) at the location x. The sign of the derivative tells us how increasing/decreasing that parameter will increase/decrease the loss. Picture a parabola x^2, deriv is 2x, negative for x < 0 and positive for x > 0; a negative deriv means decreasing x will increase x^2 whereas a positive deriv means increasing x will increase x^2. Since we want to minimize the loss, we update the parameters with the negative of the derivative. The magnitude of the derivative says how much we expect the loss to change for a given amount of change in that parameter. Or in a relative sense, the parameters whose derivatives have a large relative magnitude are currently affecting the loss the most. We can update our parameters in f by subtracting the gradient scaled by a step size: theta’ = theta - mul(step, grad(loss(f(xi, theta), yi))) and I’m writing f(xi, theta) to make it clear now that f is parameterized by theta. This is basic stochastic gradient descent (SGD). There are other optimizers. Our f is really a giant space of possible functions, one for each possible configuration/assignment of parameters theta. Each architecture of f will be its own space of possible functions. The architecture of f is the design of how we choose to multiply matrices and what other operations we utilize.
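
A toy sketch of the whole loop on a made-up linear model and squared loss, with the gradient written by hand instead of autodiff; everything here (the target function, step size, etc.) is just illustrative:

    import numpy as np

    rng = np.random.default_rng(0)
    theta = rng.normal(size=3)             # random init: almost surely a bad f
    step = 0.1

    def loss_and_grad(theta, x, y):
        pred = theta @ x                   # our f(x, theta) is just a dot product
        loss = (pred - y) ** 2
        grad = 2 * (pred - y) * x          # d loss / d theta, one entry per parameter
        return loss, grad

    for _ in range(100):
        x = rng.normal(size=3)
        y = x.sum()                        # pretend dataset: the function we hope to fit
        loss, grad = loss_and_grad(theta, x, y)
        theta = theta - step * grad        # move against the gradient to shrink the loss
    # theta ends up near [1, 1, 1], the parameters that reproduce y = x.sum()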

Where do these gradients come from? The computer calculates them for us! Remember all the rules of derivatives? You can sometimes use those to write down a closed form expression for the derivative of your chosen loss function with your chosen architecture and you can even do this by hand, but we have a great tool called auto differentiation which takes care of this and even works when there isn’t a closed form expression. It does this by storing some information on the side for each operation that your model computes in the forward pass ie when it is running f(xi), then uses that information in a backward pass to compute the gradient at each parameter. So when we are in training mode, we typically require (at least) 2N memory for an N parameter model. With other optimizers, we may need more, like Adam (adaptive moment estimation) which keeps 2 persistent values per parameter in order to better estimate and scale the update. It does this with exponentially weighted moving averages (EWMA). Q: what is the second “moment”? I kinda understand the first but not the second. We run this iterative process of testing how good the current parameters’ loss is and updating them over and over again. The loss will not be the same on every single input in our training set, so it is important to train over all of them. One pass over the training set is called an epoch. One other note is that to make things go faster, execution is batched: we run multiple xi at the same time for a given parameter configuration. We then get a loss for each xi and an accompanying gradient for each xi which we can average together to inform our parameter update. Now we can observe that the training process is very dynamic: the order of training inputs and grouping of training inputs into batches will affect the parameter updates at each step and thus what parameters we will end up with at the end of 1 or more epochs. Given that models with well over 1M parameters are common, we will almost surely never find “the one true best” parameter configuration for a given dataset and architecture. Remember that gradient descent is a local update procedure, not a global one; we do not update the parameters in the direction of the global optimal loss minimizer, we update in the direction of the local loss minimizer. Optimizer schemes like Adam were created to avoid problems like getting stuck in a local minimum, where if we are only taking local steps to minimize the loss, we risk getting to a valley of low loss and never leaving, meanwhile that valley is actually at a pretty high loss compared to the rest of the loss landscape. Other schemes like weight decay unilaterally shrink the weights of all parameters (multiply each parameter by a scalar in [0,1]) to avoid overfitting, where sure we might find some configuration of parameters that perfectly fit our training data, but do poorly on new unseen (validation/production) data. Q: is there a selective version of weight decay that multiplies each parameter with its own scalar?
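
A cartoon of the Adam-style bookkeeping plus a weight-decay shrink: two EWMAs per parameter (mean of gradients, mean of squared gradients) that scale the step. Names and hyperparameters follow the usual convention but this is a sketch, not a faithful Adam (bias correction is skipped):

    import numpy as np

    def adamish_step(theta, grad, m, v, step=1e-3, b1=0.9, b2=0.999,
                     eps=1e-8, weight_decay=0.01):
        m = b1 * m + (1 - b1) * grad             # EWMA of gradients (first moment)
        v = b2 * v + (1 - b2) * grad ** 2        # EWMA of squared gradients (second moment)
        theta = theta - step * m / (np.sqrt(v) + eps)   # per-parameter scaled step
        theta = theta * (1 - weight_decay * step)       # weight decay: shrink everything a little
        return theta, m, v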

We now need to talk about nonlinearity, so first, what is linearity? Linearity is a property of a function which satisfies f(a + b) = f(a) + f(b). Multiplication is linear because for example c (think of c as the function which multiplies by c) gives c(a + b) = ca + cb. Aka distributivity over addition. The dot product is linear because (a b)((c d) + (e f)) = (a b)((c + e) (d + f)) = a(c + e) + b(d + f) = ac + ae + bd + bf = (a b)(c d) + (a b)(e f). This means matrix multiplication is linear as well. This actually presents a problem: the basis for organizing our model into vectors and doing matrix multiplies means we could only ever model linear approximations of our data. But lots (most?) of the interesting applications of machine learning are exactly to those problems which are nonlinear. So we have to introduce nonlinearity into our architecture to give them more modeling power. Also, as we build bigger architectures under the assumption that more parameters gives us more modeling power / ability to fit more complicated datasets, we need the nonlinearity to prevent our architecture from collapsing. For example, say you hear that more parameters is more better, so instead of having just one matrix of parameters of shape (V,C) for the embedding dictionary, you decide to tack on another (C,C) matrix so that we embed like (1,V)(V,C)(C,C). Well unfortunately, by associativity, we can precompute the (V,C)(C,C) matrix product to end up with a (V,C) embedding dictionary again. In this case, adding more parameters does not change our modeling power whatsoever. If you introduce a nonlinearity that prevents that matrix multiply from happening like nonlinear((V,C))(C,C) then that is a different story. There are plenty of fancy nonlinear functions commonly used, but the basic one is relu for rectified linear unit. The output of relu is simple: x<0 -> 0 and x>=0 -> x; negative x is clamped to 0 and positive x is the identity. This function is nonlinear because by counterexample relu(-1 + 1) = relu(0) = 0 but relu(-1) + relu(1) = 0 + 1 = 1. In our previous example, we could usefully use another (C,C) matrix by embedding like relu((1,V)(V,C))(C,C). When applied to a vector/matrix, relu is applied to each element. And the derivative of relu, which we need to be able to update our parameters, is very easy: 0 for x<0 and 1 for x>0. It is also very fast to compute. In the context of matrix multiply, applying relu(Ax), where Ax is the dot of every row of A with the vector x, clamps every negative dot product to 0 and lets all the positive dot products pass through. So any feature vector in A that has angle < 90 with x will result in a nonzero value in that position in the output vector, and zero everywhere else. This gives us an important building block for a 2 layer unit (C,D)relu((D,C)(C,1)) where the (C,1) is an input and we have a (D,C) matrix and a (C,D). Commonly D >= C and in transformers it is always (joke) D = 4C. What does this do? Sticking with D = 4C, the first (4C,C) matrix is 4C features that we get to check against our input, comparing with the dot product. We then relu these 4C similarity results so that only the positive ones remain, while the others go to 0. We then multiply with a (C,4C) matrix which has C features in a 4C dim space. Each of these features is a “feature of the similarity results with the first 4C features”.
The output of the layer is something like: compare this vector to 4C features, then compare those similarities to these C similarity patterns and the output is a vector which for each coordinate, says how much it matched the corresponding similarity pattern. You can go crazy by repeating this pattern of relu + matmul over and over again and we get what is called a multilayer perceptron or MLP and is where the deep in deep learning comes from. The relu (or other nonlinear “activation” function) prevents the matmuls from just collapsing into one matrix. Q: is there a fused relu-dot like operation? It would relu each pairwise product then sum. This would give you a modified similarity measure where it sums all the products with the same sign. Another related Q is about another sum after a regular relu(Ax) so sum(relu(Ax)) which counts the total “amount of similarity” between the features in A and the vector x.
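
A sketch of one such 2 layer unit with D = 4C, random weights, hand-rolled in numpy:

    import numpy as np

    def relu(x):
        return np.maximum(x, 0.0)        # clamp negatives to 0, keep positives

    C = 8
    D = 4 * C
    W1 = np.random.randn(D, C)           # D feature rows to compare the input against
    W2 = np.random.randn(C, D)           # C "features of the similarity results"

    def mlp_block(x):                    # x has shape (C,)
        sims = W1 @ x                    # (D,) similarity of x with each of the D features
        kept = relu(sims)                # only the positive matches survive
        return W2 @ kept                 # (C,) how much we matched each similarity pattern

    out = mlp_block(np.random.randn(C))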

We now have enough to look at the transformer architecture which is very popular in large language models as a next token prediction function. As we’ve seen, we have some context of tokens (integers) in our vocabulary (T) and we embed them into (T,C) using a (V,C) dictionary. We seek a function f : (T,C) -> (T,V) where each row of the output is a probability vector over each item in our vocabulary. Remember we can turn any vector into a pvector with softmax, so whatever (T,V) f might have given us, we can apply softmax to each row of (T,V) to get each row into a pvector. This is slightly different than our previous formulation of this problem where f : (T,C) -> (V) where we only returned a single pvector, what gives? Our new formulation f : (T,C) -> (T,V) gives a next token prediction for every prefix of the input. This is useful because a) we already have that batch of data loaded into memory, we may as well compute on it and use it to update our parameters and b) it lets us run the model on inputs which are less than T tokens long. Q: how useful are the gradients for the short prefixes? We can easily get training data for this function by just taking all the text on the internet and grabbing random sequences of length T+1, feeding the first T tokens as xi and the last T tokens as yi. Ex. T=4, “abcde” is our example, we want f(“abcd”) -> “bcde” ie f(“a”) -> “b” f(“ab”) -> “c” f(“abc”) -> “d” f(“abcd”) -> “e”. Remember though that the output is a pvector over our whole vocabulary, so the target pvector is a one-hot V-vector ie the probability distribution with all probability on the single letter of our target. We can easily see this objective is complicated when there is more than one observed next token for a given input, ie we will have many cases where f(c) -> x and f(c) -> y. This means the model’s output pvector in this case will need to put some prob on x and some on y to minimize its loss against all the training data (or in some proportion to how often they appear in the training data). For what follows, I’ll mostly focus on the final next token prediction, as if we had f : (T,C) -> (V), because everything else is the same sans masking (later). So, the most important part of the transformer is the attention step, where attn : (T,C) -> (T,C) ie it produces 1 vector for each input vector, but is a function of all of (T,C). Contrast this to the mlp or feedforward network we discussed which is mlp : (C) -> (C), which we are free to apply at each position so row-wise-mlp : (T,C) -> (T,C), but the computation of each output row only depends on the corresponding input row. Attention is important because it produces an output vector per position that is dependent on the entire input. I suppose we also need to mention that our architecture will be built in an additive fashion, where f(xi) = xi + attn(xi) + … ie resnet style. This is useful because grad(f(xi)) = grad(xi) + grad(attn(xi)) + grad(…) so we get gradient contributions immediately on xi (right to the embedding layer), then some more from the attention layer and so on. I think there is also some argument that the trivial solution attn(xi) = 0 would still map f(xi) -> xi, ie learning the identity at least gives the model a starting point.
Whereas it would have “more to learn” if f(xi) = attn(xi) + … This setup is also called the residual stream where we keep adding functions of the stream back into the stream, like f(xi) = xi + f1(xi) + f2(xi + f1(xi)) + f3(xi + f1(xi) + f2(xi + f1(xi))) + … Another perspective is that we are treating f like the derivative, we want it to tell us, “what should I add to my input vectors such that they match the next token output vector”. Okay so how is attn actually computed? We’ll need a few matrices: S (C,C), V (C,C), and O (C,C) (note this is a different introduction than the paper, we’ll get to Q and K in a second). Our goal is to return a single vector (C) that is the weighted sum of the rows of (T,C) under the transform V, then passed through O. So take our input (T,C), apply V to each row (T,C)(C,C) -> (T,C), then take a weighted sum with a vector (T) and pass through O, so we do (C,C)transpose((1,T)(T,C)) -> (C,1). Altogether that looks like (C,C)transpose((1,T)(T,C)(C,C)) -> (C,1). We compute this weighting vector by computing the similarity of the last row of (T,C) with every other row of (T,C), which is where the “interaction” or dependence among the whole input comes from. We could compute this similarity using a dot product, but that is a bit too simple. So we use our S matrix (C,C) to compute the similarity (T,C)(C,C)(C,1) -> (T,1). Having the S matrix lets the vector transform before the dot product is computed. We chuck in a softmax on the similarity result (T,1) which makes it into a pvector, and again nonlinear things can model more complicated things than linear ones. It also means the weighting values are nicely behaved in range (0,1). Altogether in my weird notation it looks like (C,C)transpose(transpose(softmax((T,C)(C,C)(C,1)))(T,C)(C,C)) -> (C,1), or in symbols O softmax(XSr) XV (transposes elided) where r is the last row. The original presentation of this uses two matrices Q and K instead of S, where S can be thought of as QK (C,C)(C,C) -> (C,C), but I like seeing it as a single thing. The reason to separate them in practice is that we use multi-headed attention. Let’s choose a dim (H) and a number of heads (N) so C = N * H, now slice up our O, Q, K, V matrices into (N,C,H) (still will have CC = CNH elements). Each head of attention gets the input xi and operates independently of the other heads. For a single head where Qi (C,H), Ki (C,H), we can still compute our similarity using the matrix QiKi.transpose: (T,C)(C,H)(H,C)(C,1) -> (T,1). This is effectively a low rank (since H < C) representation of a matrix S, or another way to think of it is as projecting the vectors into a lower dim H and then computing their similarity. The head will have computed the value vectors using Vi (C,H) in dimension H so (T,C)(C,H) -> (T,H) and after combining them with a weighted sum, we get (H), but we want (C). We use the output slice Oi (C,H) to map us back up into (C). Each head independently computes a vector (C) from the input xi and those vectors are added together. Using equally sized heads is convenient for implementation. Q: where is the research on variable head sizes? I would think that distributions would be best represented with some specific collection of n1 heads of size s1, n2 heads of size s2, etc. where your goal is to use the smallest head size possible. Are our current heads the minimum size possible and we just expect deeper models to mimic larger heads?
To return to the full case of (T,C) -> (T,C), our attention heads do the same process applied at each position, with the caveat that when we take the weighted sum of the value transformed vectors, we need to zero out the vectors that come from positions forward in time (so called causal and/or decoder / autoregressive). We can also choose to not do this (so called encoder) if we are interested in, for example, sentiment analysis where we always look at the full text and only produce one output. Note that some of the dimensions above are a bit sloppy in terms of multiplying on the left/right w/ transpose or not. Also, I presented the QK as S but left O and V separate, even though we can view them similarly as an OV low rank factorization as well; this is the stance from the circuits paper.
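
A sketch of one causal head in roughly the notation above: Q, K, V, O slices of shape (C,H), similarity via the low rank QiKi.transpose, a softmax per row, and an upper-triangular mask for the causal part. The 1/sqrt(H) scaling is the convention from the paper (not discussed above); all variable names are mine:

    import numpy as np

    def softmax(z, axis=-1):
        z = z - z.max(axis=axis, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=axis, keepdims=True)

    T, C, H = 5, 16, 4
    X = np.random.randn(T, C)                 # (T,C) input (embeddings + positions)
    Q, K, V, O = (np.random.randn(C, H) for _ in range(4))

    scores = (X @ Q) @ (X @ K).T / np.sqrt(H)       # (T,T) similarities via the low rank QK^T
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)        # causal: position t can't look at t+1, t+2, ...
    weights = softmax(scores, axis=-1)              # each row is a pvector over visible positions
    head_out = weights @ (X @ V)                    # (T,H) weighted sums of value vectors
    out = head_out @ O.T                            # (T,C) mapped back up, ready to add to the stream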

Q: what is the relation between QK and a metric tensor? QK gives pos and neg outputs, so not a metric, but exp(XQKX) (transpose appropriately) gives only pos, though we still get something like inverse distance where small attention is like large distance. Though we can fix that with exp(-XQKX). And QK isn’t constrained to be symmetric. My other wonder with QK is that I find it non-symmetric to think about QK acting to transform X on the right into X’ to be compared with X^T and not more like X^Tsqrt(QK)sqrt(QK)X where now each side is being transformed “equally” to meet in the middle.

One thing to notice in attention is that because we take a weighted sum of value vectors, the result is permutation invariant. If we reorder the inputs (scrambling our input tokens), then the similarity weighting we compute with QK “follows” the new order and the sum of attention weighted value vectors discards any ordering information. This is weird because it means the value vector we add back into the residual stream is the same for “ab” as it is for “ba”, assuming “a” and “b” are tokens. This doesn’t match our intuition or so called inductive bias because we want the value vector result to be dependent on the tokens AND their ordering. So to do that, we add more vectors! We can either add a learned positional vector (one vector per position ie a (T,C) matrix) or some other computed positional vector dependent on the position of the token. Then we compute on xi+pi instead of just xi. This then means that “a”+p0,”b”+p1 != “b”+p0,”a”+p1. I think the overall goal is how they frame it in the rotary positions paper, where we want dot(pos(u, m), pos(v, n)) ~ dot(u, v) * f(m - n), ie we want the dot product of the positioned vectors to be related to the dot product of the original vectors with some modulation by the relative position. I suppose in the general case it is ~ dot(u, v) * f(m, n) but I think relative position makes the most sense since we don’t want dependence on absolute position, rather the relative position.
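
A sketch of the learned-position variant: one extra (T,C) parameter matrix added to the token embeddings so “ab” and “ba” stop looking identical to the weighted sum (toy numpy, not the rotary scheme):

    import numpy as np

    V, T, C = 27, 8, 16
    E = np.random.randn(V, C)            # token embedding dictionary (learned)
    pos = np.random.randn(T, C)          # one learned vector per position

    tokens = np.array([1, 2])            # say "a" then "b"
    x = E[tokens] + pos[: len(tokens)]   # xi + pi, shape (2, C)
    # reversing the tokens now gives a different x, so attention's permutation-invariant
    # weighted sum can still tell "ab" from "ba"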

Q: if resnet (concept, not the og model) is so great, why isn’t everything residual style? Like when calculating the result of self attention with V, why not do x+Vx? Well one reason is that with multihead, V downprojects so we can’t do x+Vx. So then maybe I’m asking why not O(Dx + VDx) where we downproject, add with a function of the downproject, then upproject with O. Maybe also in the QK?

With all the focus on mechanistic interpretability and reverse engineering why transformers/attention work so well at next token prediction, it seems natural to ask why we can’t design networks that are more interpretable to begin with. One part of the interp is to try to isolate/identify modules/heads that seem to do a particular thing. But oh wait that is really hard because we can’t even guarantee a head does a “single” thing, or even if it does a single thing, the residual stream isn’t forced to have an interpretable structure or even a non-changing structure between layers. So that all makes me think about how you could train attention heads independently. The training objective isn’t clear, maybe it is the full task but we can’t expect a single head to do that well, but maybe we can find the top N heads with minimum cross correlation or something. This is related to the idea of the task space having an inherent amount or complexity of structure and things like analysis of variance where we’d be interested in finding the minimum set of attention heads using the minimum amount of weights that maximally fits the data, and that task has some pareto frontier. But ultimately we seek to match the true underlying structure of the data in terms of multiplicity of features and dimensionality of features in the single head case. The single heads are the best predictors of the next token given they act independently, but we can then ask what are the best single heads that produce features that are best for the second layer attention heads to predict the next token. Are those sets of heads the same as the first set of heads? How can you train those heads independently of feeding their results to the second layer of heads? Can you train the second layer of heads independently of the first layer of heads? There is research on module level training instead of end to end. Q: can you start training a model with C hidden dimensions, then increase to C+1, C+2, … on some schedule? Are there such things as continuous/fractional dimensions so that you could take the derivative wrt dimension number? Ie tell me how much I gain from adding another dimension. One thing that bugs me is that all these nets are invariant to channel permutation, like we’re searching over such a redundant function space! Wouldn’t it be nicer to start with C=Cmin where we posit that Cmin is the smallest dimension of the single maximally informative feature; learn a net that uses this “one” (how can we enforce that?) feature, then move to C = Cmin + Cmin2 channels and freeze the first Cmin features to force it to learn the second most informative feature of minimal size Cmin2. Repeat for your resource budget to desired accuracy. Or you start the other way with C=1 and learn the best 1 dimensional feature, then add another 1, then maybe the best 2 dim feature, etc. up until you get to C=…+Cmax and we might expect there to be a few features of size close to Cmax and many features of size closer to 1 or something small; ie a long tail of feature size (or really a long head if we order by feature size). If we wanted to imbue each channel with some notion of identity to break the channel permutation invariance, maybe we have a layernorm type thing that rescales each channel i with something like C*pow(i+1, a) for some constants C and a.
Ie channel 0 is defined to be the most important so we allow it to have the most scale, channel 1 second most important so it has less scale, and the distribution of scales of the parameters might be learned. Or is allowing a configurable variance the right thing here?

Q: Why is the dot product the special one? dot(u, v) = mapreduce(mul, add) but there are the other choices mapreduce(mul, mul) mapreduce(add, mul) mapreduce(add, add)

Q: Is QK positive semidefinite? Should it be? x^TAx >= 0 means we never flip the sign of any channel for all x. Though the exp in softmax turns any of those negatives into small positive anyways.

Symmetry seems to be all the rage. O(d) is the Lie group (a continuous group) of orthogonal transformations, eqv (d,d) ortho matrices (inv(A) == transp(A)), eqv symmetries of the sphere in d dims. SO(d) (special) are the rotations ie excluding the flips. In physics U and SU are the unitary versions for complex entries. Equivariant wrt a symmetry G is when f(Gx) = Gf(x) ie applying any transformation in the symmetry on the input is the same as applying it on the output. Invariant is when f(Gx) = f(x). First fundamental theorem of invariant theory: f is O(d) invariant iff f(A) = f’(gram(A)); remember gram(A) is all pairwise dot products, gram : (n,d) -> (n,n). f is O(d) equivariant iff f(V) = sum(gi(gram(V))V) ie a weighted sum of V, where the weights are computed from the gram matrix. E(d) is the euclidean group which preserves euclidean distance, so O(d) + translations. For a point cloud, can get E(d) invariance by subtracting the mean/centroid of the points, “canonicalization”. Ie a way to choose a unique representative (similar to choosing the pair (u,v) st u < v or the sorted version of a set when programming). vid1 paper1
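
A quick numerical check of the invariance statement: apply a random orthogonal matrix to a point cloud and the gram matrix doesn’t move (pure numpy, the QR trick is just an easy way to sample an orthogonal matrix):

    import numpy as np

    n, d = 5, 3
    A = np.random.randn(n, d)                     # a point cloud: n points in d dims
    G, _ = np.linalg.qr(np.random.randn(d, d))    # random orthogonal matrix (G.T @ G = I)

    gram = lambda X: X @ X.T                      # all pairwise dot products, (n,n)
    assert np.allclose(gram(A @ G), gram(A))      # acting with O(d) leaves the gram matrix alone

    A_centered = A - A.mean(axis=0)               # subtract the centroid: translations drop out too (E(d))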

diffusion and score-based generative models

From this amazing video

note to self using grad everywhere but remember grad(f : n -> 1) : (1,n) vector of derivatives and jacobian(f : n -> m) : (m,n) matrix of derivatives

  • we have data samples, like images, that we suppose come from some underlying distribution. these samples are very sparse in the full space, ie the density is very low/spread out among all possible images
  • to model the data, we can use a nnet f to predict p(x) but nnet hard to restrict to a pdf. we can always use exp(f(x)) to get nonnegative, but we need to divide by a normalizing term Z to get a true pdf. softmax takes care of this in the discrete case, but is intractable here. Z is computable in closed form for eg gaussian. alternatives to exact normalizing constant are:
    • approximate the constant (energy based models)
    • restricted nnets (autoregressive, normalizing flow, variational ae)
    • model the generation process only (GAN)
  • for a pdf p(x), the (Stein) score function is grad(log(p(x)), x) (score is a crazy name to me). gradient of the log prob
    • for p(x) = exp(f(x)) / Z, score(p) = grad(log(p)) = grad(f(x)) - grad(log(Z)) and the second term is 0 since Z doesn’t depend on x
    • so s(x) = grad(f(x)) which we can get with backprop for any nnet f
  • to estimate score from data, need an objective, for score(x) : d -> d
    • replace fisher divergence with score matching because fisher divergence requires knowing the true score function. but score matching requires the trace of the jacobian, so 1 backward pass per dimension of the input
    • sliced score matching (from sliced fisher div) projects the score (vector field) to random directions. end up with vT grad f(x) v = vT grad vT f(x). ie we dot with a direction vector v on the output of the nnet so we get a scalar output and then again after the backward pass to get a scalar result (on both sides). so then we only need one backprop per sample+direction
    • denoising score matching, derive a noisy distribution q from p by adding gaussian noise and then try to find the score of this dist
  • sampling, assume we’ve learned a good score function s
    • if we directly follow s by integrating random points, we’ll get clusters boo. so we use langevin dynamics to add noise at each step during integration (toy sketch after this list)
    • on its own, this gives bad samples b/c our learned score function is very inaccurate in low data-density regions (remember real images are very sparse in the space of all images)
    • by adding noise to data points, we can learn a better s in regions around our data points
    • extend by doing this for multiple noise levels to get one s for each noise level
    • extend by doing noise conditional score model where the score model gets the noise level as an input s(x, std) (related to DDPM todo)
  • control the generation process
    • to condition on some input (like a prompt/caption), we want to train a p(x|y). expand with bayes rule and compute the score function and we get score(p(x|y)) = score(p(x)) + score(p(y|x))
    • p(y|x) is just a model that takes an image and gives a caption (or whatever we are conditioning on). Can use any model we want in conjunction with our score(p(x))
    • though really we need a p(y|x, t), a time dependent classifier (can train a classifier on noised images)
  • probability evaluation
    • train a model to estimate score(pt(x)) for varying noise level t in [0, T], where p0(x) are our data points, pT(x) is a fully noised version and is eqv to eg a gaussian, and pt(x) is somewhere in between.
    • sde stochastic differential equation dx = f(x, t)dt + g(t)dw where g(t)dw is the stochastic part
    • for this we choose sde of the form dx = sigma(t)dw, and can derive a reverse time sde which depends on a time dependent/condition score function s(x, t)
    • can derive an ordinary differential equation ode (probability flow ODE) which is a function of s(x, t)
    • lets us compute p(x) for any image using an odesolver to step s(x, t)
  • Q&A says that unet is a common architecture for the actual nnet
  • some more practical investigations in this video
    • stochastic still gives best generation results
    • most training time spent in the middle third of t
    • the nnet needs scaling so that across all noise levels the magnitudes of values are similar
    • they use a weighted skip connection so that at low noise levels, they effectively predict noise to correct input, at high noise levels, predict signal to override input
    • the skip connection effectively makes it predict the noise b/c its like img = noisy(img) - noiseof(noisy(img))
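
A toy sketch of the Langevin sampling loop mentioned above, using the exact score of a 1d standard gaussian (score(x) = -x) as a stand-in for a learned s(x); step size and iteration count are arbitrary:

    import numpy as np

    def score(x):
        # stand-in for a learned score model; this is the exact score of N(0, 1)
        return -x

    rng = np.random.default_rng(0)
    x = rng.uniform(-5, 5, size=1000)     # start from spread-out random points
    eps = 0.01
    for _ in range(1000):
        noise = rng.normal(size=x.shape)
        x = x + eps * score(x) + np.sqrt(2 * eps) * noise   # score step + injected noise

    # drop the noise term and every point collapses onto the mode (the "clusters" problem);
    # with it, x ends up distributed roughly like the target density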

structured state spaces

From this video. I preferred it to the other mlsys presentation

  • considering signals, a sampled version of some underlying continuous process
  • trying to approximate a 1d signal over time u
    • to do that we want to derive some state x from u that we update over time
    • example is exp moving avg EMA where x = lerp(x, u, alpha) which has infinite context length (summarizes entire context) and is updateable in constant time
    • HiPPO (high order polynomial projection operator) extends this idea to deriving a vector x of desired size which are coeffs to a polynomial basis and can be efficiently updated
    • depends on a measure which defines the weighting of the past signal for how much we care about it
    • state update is an ode (hippo operator) x’(t) = Ax(t) + Bu(t) where A (hippo matrix) is the state transition and B tells us how to mix in the signal
    • you can then plugin the state at x(t) into the poly basis to reconstruct the signal. loss is the squared error (I think?) weighted by the measure
    • “all hippo operators have displacement rank <= 3 so have linear-time matrix-vector multiplication” todo what is displacement rank
  • if we use hippo in a layer/model, we are increasing dimensionality from 1 -> d (where d is the degree of poly basis)
    • so instead we output a 1d y = Cx + Du where we have (1,d)(d,1) + (1)(1) so dot the state with a vector and add a scalar multiple of the input (skip connection)
  • the x’ = Ax + Bu and y = Cx + Du is an SSM (in most general form, ABCD can be time dependent); a toy discretized sketch appears after this list
    • output y(t+1) can be computed online, but when we’re training and we know the whole sequence, we would like to avoid stepping through the recurrence one position at a time
    • whole output y(t) can be computed convolutionally y = conv(u, K) skipping the computation of the state signal x(t) (though computing K is still slow)
      • K(i) = C matpow(A,i) B for i in L so yeah computing K naively is real slow
    • we then discretize ABCD for some step size
  • structured state space model gives structure to the ABCD which lets us compute K efficiently
  • on continuous signals, s4 has good zero-shot perf on resampled signals which tanks discrete methods (though is that passing in the step size into the discretization method?)
  • https://srush.github.io/annotated-s4/
    • A is the hippo matrix; B, C, the step size, and the scalar D are learned
    • stack one SSM per input dimension to get a layer from C -> C
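
A toy sketch of stepping a discretized x’ = Ax + Bu, y = Cx + Du recurrence online: random ABCD and a naive Euler discretization, not the hippo/s4 parameterization, just to show the constant-time state update:

    import numpy as np

    d = 4                                      # state size
    rng = np.random.default_rng(0)
    A = rng.normal(size=(d, d)) * 0.1          # state transition (hippo would prescribe this)
    B = rng.normal(size=(d, 1))
    C = rng.normal(size=(1, d))
    D = rng.normal(size=(1, 1))
    dt = 0.1                                   # step size used for the discretization

    A_bar = np.eye(d) + dt * A                 # Euler: x[k+1] = x[k] + dt * (A x[k] + B u[k])
    B_bar = dt * B

    x = np.zeros((d, 1))
    u_seq = np.sin(np.linspace(0, 6, 60))      # a 1d input signal over time
    ys = []
    for u in u_seq:
        x = A_bar @ x + B_bar * u              # constant-time online state update
        ys.append((C @ x + D * u).item())      # y = Cx + Du, one 1d output per step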

Links

  • Attention Is All You Need
    • self/cross attention, transformer
  • Deep Residual Learning for Image Recognition
    • resnet, residual, skip/shortcut connection
    • x + f0(x) + f1(x + f0(x)) + …
  • A Mathematical Framework for Transformer Circuits
    • residual stream, QK/OV circuit
  • Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View
    • numerical integrator, Lie-Trotter splitting scheme and Euler’s method, Strang-Marchuk splitting scheme
    • attention is like diffusion (interaction) and FF like convection (independent particle flow)
    • Macaron layer reorders so we do 1/2FF then Attn then 1/2FF
  • Improving Transformer Models by Reordering their Sublayers
    • related to above where they test loads of variants of (s f)* layers (s for self attention). No scaling on f layers
  • Residual Networks Behave Like Ensembles of Relatively Shallow Networks
    • unroll a deep residual network to view as a sum of paths, like in the circuits paper
    • for -f-g- you get paths ----- -f--- ---g- -f-g-
  • Improving neural networks by preventing co-adaptation of feature detectors
    • dropout
    • randomly zero out units (activations) during training, average outputs at test time. Summing log-probs is same as geometric mean of “experts”
    • but then they also do a weird renormalization of weights (or is it inputs, I’m confused) if they exceed an L2 constraint. and a weird initialization
    • the paper uses p for probability of element to be 1, and pytorch uses p for probability of element to be 0, so paper multiplies by 1/p in the forward pass and torch does 1/(1-p)
  • Layer Normalization
    • confusing b/c they present it as though its norming a layer’s weights (like weight normalization, where w: vector = vg/norm(v) and v and g are learned), but in torch it just acts on the data passing through. I find the neuron focus confusing
    • interesting they motivate it as a way to speed up learning
    • as a post layer norm, like layernorm(mlp(x)), it makes the output invariant to scaling of the mlp’s matrix
    • as a pre layer norm like mlp(layernorm(x)) I think it makes the output invariant to scaling of the data
    • post layer norm was the original architecture in transformer, but pre layer norm is the gpt2+ standard
    • table 1 gives more invariance properties
    • pre layer norm seems to make sense b/c (I think) your layers experience less covariate shift (from shimodaira 2000 via batch norm), as your layers learn, you don’t want a doubly constantly moving target, obv there will be refinement in the residual stream representation, but don’t make it harder than it needs to be with shifts or scaling to worry about too
    • out = ((x - mean(x)) / sqrt(var(x) + eps)) * w + b where w and b are learned. applied to each input independently (or configurable by shape)
  • Understanding and Improving Layer Normalization
    • adanorm: y = f(centerscale(x)) * centerscale(x) where centerscale(x) = (x - mean(x))/std(x) and f(x) is uniquely C(1-kx) by their requirements for constants C and k
    • derivatives of mean and variance ie backward more important than forward normalization
    • w and b parameters of layernorm can lead to overfitting
    • 2.1 eqn 1 defines h as dot(g, N(x)) + b but that has to be a typo right? must mean hadamard (entrywise product)
  • Seeing is Believing: Brain-Inspired Modular Training for Mechanistic Interpretability
    • regularize overall layer weights by d_ijk|w_ijk| (locality; 2d euclidean distance between neurons (rows) after laying out layer neurons (matrix)) and occasionally try row swapping within a layer if it minimizes this loss too, since they say gd gets stuck in a local minimum
    • pretty pictures
    • the mnist pictures show the last layer of digit “neurons” (I suppose its a fine term when visualizing rows as dots but I still have an aversion to it) with one dot per digit arranged in a circle; did they use a circular layout in the distance calculation? I guess it is because the location of some digits change (see last page). But idk if that is a useful layer to have positions on; I guess if the data were skewed so all 9’s were in the top right and all 1’s were in the bottom left, then maybe. The dots layout is still confusing me a bit b/c for example the input layer is a 2d grid of dots where each dot is a scalar, but then the actual layers are one dot per row/matrix (for 2d/3d) right?
  • Language Modeling with Gated Convolutional Networks
    • gated linear unit, GLU
    • glu = mul(sigmoid(Wx), Vx), bilinear = mul(Wx, Vx), lstm style (gated tanh unit) gtu = mul(tanh(Wx), sigmoid(Vx)) leads to vanishing gradients b/c you get tanh’ and sigma’ on both (respective) factors of the chain rule
    • they call this is a multiplicative skip connection for the gradient. todo understand the gradient calculation better since it seems they ditch the
    • why does google rank the “glu variants improve transformer” paper over this one
    • fig 1 diagram shows taking half the vector for W and the other half for V, like glu = mul(relu(W@left(x)), W@right(x)), reminds me of affine coupling layer (flow matching)
  • NICE: Non-linear Independent Components Estimation
    • affine coupling layer, u, v = split(x); y = concat(u, mul(v, f(u)) + g(u))
    • didn’t read much, came via vid
  • HiPPO: Recurrent Memory with Optimal Polynomial Projections
    • todo
  • https://github.com/lucidrains/x-transformers
  • https://github.com/karpathy/minGPT/blob/master/mingpt/model.py