This is an attempt to write down my understanding of the paper “Representation Learning with Contrastive Predictive Coding” by van den Oord et al.

Ultimate Goal

The primary goal is to learn speech representations that can then be used for numerous downstream tasks such as ASR, speech translation, etc. Language models typically learn this sort of representation through next-token prediction: the model produces a latent representation from which the output distribution for the next token is generated, and the loss function is usually cross-entropy.
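For contrast, here is a minimal sketch of that next-token objective (the vocabulary size, shapes, and names below are arbitrary illustrations, not something from the paper):

```python
import torch
import torch.nn.functional as F

vocab_size, batch, hidden = 32000, 4, 512

latent = torch.randn(batch, hidden)            # latent representation at step t
output_proj = torch.randn(vocab_size, hidden)  # projection to the vocabulary

logits = latent @ output_proj.T                # unnormalized next-token distribution
next_tokens = torch.randint(0, vocab_size, (batch,))
loss = F.cross_entropy(logits, next_tokens)    # the usual LM training loss
```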

This approach is not well suited to speech, because the speech signal is high dimensional whether it is a raw waveform or a log mel-spectrogram, and directly predicting such high-dimensional targets does not produce strong representations.

The Re-oriented Approach

Assume we have a continuous signal $\vec{X} = \{\vec{x}_1, ..., \vec{x}_M\}$, where each $\vec{x}_i$ is a vector representing a window of the signal. Given $\vec{x}_t$, we now predict the latent representation of $\vec{x}_{t+k}$ for some offset $k > 0$. This is presented in the paper in the following way: instead of modelling $p(\mathbf{x}_{t+k} \mid \mathbf{c}_t)$ directly, the approach models $f_k(\mathbf{x}_{t+k}, \mathbf{c}_t) = \exp(\mathbf{z}_{t+k}^T W_k \mathbf{c}_t)$, which is proportional to the density ratio $\frac{p(\mathbf{x}_{t+k} \mid \mathbf{c}_t)}{p(\mathbf{x}_{t+k})}$, i.e., how much more likely the sample at step $t+k$ is given the context at step $t$. Here, $\mathbf{c}_t$ is the context representation and $\mathbf{z}_t$ is the latent representation. In some way, we are looking for the relationship between the two, and our loss function will reflect that more clearly.
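To make the score concrete, here is a rough sketch in PyTorch (the dimensions and variable names are my own assumptions, not from the paper's code): $f_k$ is just an exponentiated bilinear product between the future latent and the context.

```python
import torch

z_dim, c_dim = 256, 512                               # hypothetical latent / context sizes

W_k = torch.randn(z_dim, c_dim, requires_grad=True)   # one learned W_k per offset k
z_tk = torch.randn(z_dim)                             # latent representation z_{t+k}
c_t = torch.randn(c_dim)                              # context representation c_t

# f_k(x_{t+k}, c_t) = exp(z_{t+k}^T W_k c_t): an exponentiated bilinear score,
# trained to be proportional to p(x_{t+k} | c_t) / p(x_{t+k}).
score = torch.exp(z_tk @ W_k @ c_t)
```

In practice the exponential is never computed explicitly; the bilinear term goes straight into a softmax, as in the loss below.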

Training

Forward

We have an encoder that generates the latent representation $\mathbf{z}_t$ from $\mathbf{x}_t$, and an auto-regressive layer that generates $\mathbf{c}_t$ from $\mathbf{z}_t$ and the previous context $\mathbf{c}_{t-1}$ (equivalently, it summarizes $\mathbf{z}_{\le t}$).
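A minimal sketch of this forward pass, assuming a strided 1-D convolution as the encoder over raw audio and a GRU as the auto-regressive layer (the paper's speech setup uses a deeper conv stack; the sizes here are placeholders):

```python
import torch
import torch.nn as nn

class CPCForward(nn.Module):
    def __init__(self, z_dim=256, c_dim=512):
        super().__init__()
        # g_enc: maps a window of the raw signal x_t to a latent z_t.
        # One strided conv stands in for the paper's full conv stack.
        self.encoder = nn.Conv1d(1, z_dim, kernel_size=10, stride=5)
        # g_ar: summarizes z_{<=t} into the context c_t.
        self.autoregressive = nn.GRU(z_dim, c_dim, batch_first=True)

    def forward(self, x):                    # x: (batch, 1, samples)
        z = self.encoder(x).transpose(1, 2)  # (batch, T, z_dim)
        c, _ = self.autoregressive(z)        # (batch, T, c_dim)
        return z, c

z, c = CPCForward()(torch.randn(4, 1, 16000))  # four one-second 16 kHz clips
```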

Backward

Instead of directly predicting the latent representation, the model learns to contrast between latent representations. We want the prediction made from the current context to be similar to the true latent representation $k$ steps ahead and different from the latent representations of other samples. We call these the positive and negative samples respectively.

The model is given $N$ samples, among which $N-1$ are negative samples. The positive sample is drawn from the conditional distribution $p(\mathbf{x}_{t+k} \mid \mathbf{c}_t)$, i.e., it is the actual observation $k$ steps into the future (in text, this conditional is what a language model, or even an n-gram model, approximates). The negative samples are drawn from the proposal distribution $p(\mathbf{x}_{t+k})$, i.e., at random from the overall collection of possible values (like drawing random tokens from the vocabulary of a text LM).

$$
\begin{aligned}
L_N &= - \mathbb{E}_{X} \left[ \log \frac{f_k(\mathbf{x}_{t+k}, \mathbf{c}_t)}{\sum_{\mathbf{x}_j \in X} f_k(\mathbf{x}_j, \mathbf{c}_t)} \right] \\
&= - \mathbb{E}_{X} \left[ \log \frac{\exp(\mathbf{z}_{t+k}^T W_k \mathbf{c}_t)}{\sum_{\mathbf{x}_j \in X} \exp(\mathbf{z}_j^T W_k \mathbf{c}_t)} \right] \\
&= - \mathbb{E}_{X} \left[ \log \frac{\text{similarity score between the transformed context repr. and the true future latent repr.}}{\text{sum of the similarity scores over all the positive and negative future samples}} \right]
\end{aligned}
$$
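In code this is usually implemented as a cross-entropy over the bilinear scores, where the "classes" are the $N$ candidate latents and the target class is the positive one. A minimal sketch, assuming the negatives are simply the other sequences in the batch at the same time step (a common practical choice, and my assumption rather than something fixed above):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z_future, c_t, W_k):
    """z_future: (N, z_dim) latents z_{t+k}; c_t: (N, c_dim) contexts;
    W_k: (z_dim, c_dim) bilinear weights for offset k."""
    pred = c_t @ W_k.T              # predicted future latent per sequence: (N, z_dim)
    logits = z_future @ pred.T      # logits[i, j] = z_i^T W_k c_j: (N, N)
    # For context j, the positive is its own future z_j (the diagonal);
    # the other N-1 latents act as negatives drawn from p(x_{t+k}).
    targets = torch.arange(z_future.size(0))
    return F.cross_entropy(logits.T, targets)

N, z_dim, c_dim = 8, 256, 512
loss = info_nce_loss(torch.randn(N, z_dim), torch.randn(N, c_dim),
                     torch.randn(z_dim, c_dim))
```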

This loss is minimized when the similarity score for the true future latent representation is large compared to the scores of all the other samples, and it penalizes the model otherwise. The only way for the model to perform well is to learn to contrast the positive sample against the other $N-1$ negative samples.

Some Comments

In the loss function, $W_k$ is different for each $k$. So we can swap in the appropriate $W_k$ afterwards, depending on whether we want the model to predict the next step or $k$ steps after that, as sketched below. Also, even though we sample $N$ examples, positive or negative, the loss is computed on the latent representation $\mathbf{z}$ of each $\mathbf{x}$, not on the raw input itself.
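A sketch of how the per-offset weights might be stored, assuming (purely as an illustration, not the paper's implementation) one bias-free linear layer per offset $k$:

```python
import torch
import torch.nn as nn

max_k, z_dim, c_dim = 12, 256, 512

# One W_k per prediction offset k = 1..max_k.
predictors = nn.ModuleList(
    [nn.Linear(c_dim, z_dim, bias=False) for _ in range(max_k)]
)

c_t = torch.randn(1, c_dim)
k = 3                              # predict 3 steps ahead
pred_z = predictors[k - 1](c_t)    # "swapping the weights" = picking the k-th head
```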

References

Oord, A. van den, Li, Y., & Vinyals, O. (2019). Representation Learning with Contrastive Predictive Coding (arXiv:1807.03748). arXiv. http://arxiv.org/abs/1807.03748

Hinton, G. E. (2013). Learning Distributed Representations for Statistical Language Modelling. Retrieved November 16, 2024, from http://www.cs.utoronto.ca/%7Ehinton/csc2535/notes/hlbl.pdf
