+++
title = "Prototype-based learning. Part I: GMLVQ from scratch"
date = 2024-08-10
tags = ["Python", "Data Analysis", "For fun"]
categories = ["Projects"]
+++

Description

Psychology meets data science: when I came across learning vector quantization (LVQ), I saw a connection with models of human category learning, so I decided to implement it.

The project has two parts. The first implements every operation from scratch using PyTorch. The second is still written from scratch, but it abstracts away the optimization/learning part and makes full use of PyTorch's features.

Problem

There are two main approaches to category learning: exemplar models and prototype models. Prototype learning assumes “that a category of things in the world (objects, animals, shapes, etc.) can be represented in the mind by a prototype. A prototype is a cognitive representation that captures the regularities and commonalities among category members and can help a perceiver distinguish category members from non-members”. Why can LVQ serve as a model of human learning? Because it learns prototypes and classifies examples by comparing them with those prototypes.

What is different and interesting?

Prototypes are learned. They may summarize the data, but they are not necessarily the class average, and more than one prototype per category is allowed. The dissimilarity function is a generalized squared Euclidean (Mahalanobis-like) distance whose parameters can also be learned. It does not need to be shared either: we can have one dissimilarity function per prototype. This flexibility allows LVQ to learn linearly separable categories as well as non-linearly separable ones, which is one of the weaknesses of the standard prototype model.

Objective

Despite its potential as a cognitive model, I don't have the means to test it properly (access to data from experiments with human participants). So I will focus on its usefulness as a machine learning classifier. More specifically, in this first part I'll try to replicate the results of the paper xxx, which is the one I first used to learn about this family of models.


Most articles I’ve seen about LVQ models use a notation that I personally find a bit difficult to follow when explaining certain aspects of the algorithm, such as the update rules. For that reason, I decided to make a few changes to the way the model is defined:

  • By default, vectors are row vectors. This is due to the fact that, commonly, datasets are shaped \(n\times p\) (number of observations \(\times\) number of features)
  • \({\bf x}\) = data point: \(\begin{bmatrix} x_1 & \dots & x_p \end{bmatrix}\)
  • \({\bf w}\) = prototype: \(\begin{bmatrix} w_1 & \dots & w_p \end{bmatrix}\)
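  • \({\bf R}\) = relevance matrix. In GMLVQ, it’s a dense symmetric matrix of relevance/attention weights given to the features and to their pairwise interactions, normalized so that \(tr({\bf R}) = 1\). If it is diagonal, the model becomes RLVQ and won’t account for correlations between features. If it is the identity matrix, the model is GLVQ.
  • \({\bf R}={\bf Q}^\intercal{\bf Q}\). This parameterization guarantees that \({\bf R}\) is symmetric and positive semi-definite, so distances are never negative.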

How does the model work? A trained LVQ model compares an unseen exemplar \({\bf x}\) with each of the learned prototypes stored in memory, \({\bf w}_1,...,{\bf w}_C\). There may be one prototype for each of the \(C\) classes, as here, or more than one per class. The comparison uses a dissimilarity function \(d\) defined as:

\(d({\bf x},{\bf w}_k,{\bf R})=({\bf x}-{\bf w}_k) {\bf R} ({\bf x}-{\bf w}_k)^\intercal\)

In terms of \({\bf Q}\):

\(d({\bf x},{\bf w}_k,{\bf Q})=({\bf x}-{\bf w}_k) {\bf Q}^\intercal{\bf Q} ({\bf x}-{\bf w}_k)^\intercal=\lVert ({\bf x}-{\bf w}_k)\;{\bf Q}^\intercal \rVert_2^2\)

This \({\bf Q}\)-based version is the one we will actually use for training.
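The following minimal sketch makes the equivalence concrete. It is written in NumPy for brevity, although the post's own implementation uses PyTorch. It computes the distance once through \({\bf R}\) and once through \({\bf Q}\) for random row vectors, and both results should agree. The helper names `distance_R` and `distance_Q` are hypothetical. \({\bf Q}\) is rescaled so that \(tr({\bf R})=1\).

```python
import numpy as np

def distance_R(x, w, R):
    """Squared generalized distance d = (x - w) R (x - w)^T for row vectors."""
    diff = x - w                       # shape (p,)
    return float(diff @ R @ diff)

def distance_Q(x, w, Q):
    """Same distance through Q: ||(x - w) Q^T||_2^2, with R = Q^T Q."""
    proj = (x - w) @ Q.T               # project the difference with Q
    return float(proj @ proj)          # squared Euclidean norm of the projection

rng = np.random.default_rng(0)
p = 4
x = rng.normal(size=p)                 # one data point (row vector)
w = rng.normal(size=p)                 # one prototype (row vector)
Q = rng.normal(size=(p, p))
Q /= np.sqrt(np.trace(Q.T @ Q))        # rescale so that tr(R) = 1
R = Q.T @ Q                            # symmetric, positive semi-definite

print(distance_R(x, w, R), distance_Q(x, w, Q))
```

Training only ever touches `Q`. \({\bf R}\) is recovered from it whenever it is needed for interpretation.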

The decision rule is to assign to the exemplar the class corresponding to the closest prototype:

\(\hat c=\underset{k\in\{1,\dots,C\}}{\operatorname{argmin}}\; d({\bf x},{\bf w}_k,{\bf R})\)
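The decision rule can be sketched in a few vectorized NumPy lines (an assumption for brevity; the post uses PyTorch). All point-to-prototype distances are computed at once, and each point gets the label of the argmin. `predict` and the toy prototypes are hypothetical.

```python
import numpy as np

def predict(X, prototypes, proto_labels, Q):
    """Assign each row of X the label of its closest prototype under d."""
    diffs = X[:, None, :] - prototypes[None, :, :]   # (n, K, p): x_i - w_k
    proj = diffs @ Q.T                               # (n, K, p): batched projection
    dists = np.sum(proj ** 2, axis=-1)               # (n, K) squared distances
    return proto_labels[np.argmin(dists, axis=1)]

prototypes = np.array([[0.0, 0.0], [3.0, 3.0]])      # one prototype per class
proto_labels = np.array([0, 1])
Q = np.eye(2) / np.sqrt(2)                           # R = I / 2, so tr(R) = 1
X = np.array([[0.5, -0.2], [2.5, 3.1], [1.0, 1.0]])
print(predict(X, prototypes, proto_labels, Q))       # [0 1 0]
```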

It’s fairly easy to see how the assumptions on \({\bf R}\) change the final model.

  • If it’s dense and unique, we have the explained GMLVQ.

  • If it’s diagonal, then \(d({\bf x},{\bf w}_k,{\bf R})=\sum_{i=1}^p r_i (x_i - w_{ki})^2\), which is the weighted squared Euclidean distance used by RLVQ.

  • If \({\bf R}={\bf I}\), then \(d({\bf x},{\bf w}_k,{\bf R})=({\bf x}-{\bf w}_k)({\bf x}-{\bf w}_k)^\intercal=\lVert {\bf x}-{\bf w}_k \rVert_2^2\), the plain squared Euclidean distance of GLVQ.

  • Finally, if the relevance matrix is not shared and there is one for each prototype, we get a localized version of the model (LGMLVQ or LRLVQ).
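In this NumPy sketch (an assumption; the post uses PyTorch), each variant is just a different choice of \({\bf Q}\) and therefore of \({\bf R}={\bf Q}^\intercal{\bf Q}\). The code checks which variants produce off-diagonal (feature-interaction) terms. It then shows the localized case, where each prototype carries its own \({\bf Q}_k\). The RLVQ weights and the random matrices are hypothetical values chosen for illustration.

```python
import numpy as np

def distance(x, w, Q):
    """Squared generalized distance ||(x - w) Q^T||_2^2, i.e. R = Q^T Q."""
    proj = (x - w) @ Q.T
    return float(proj @ proj)

p = 3
rng = np.random.default_rng(1)
variants = {
    "GLVQ": np.eye(p),                              # R = I
    "RLVQ": np.diag(np.sqrt([0.5, 0.3, 0.2])),      # R diagonal, tr(R) = 1
    "GMLVQ": rng.normal(size=(p, p)),               # R dense (not normalized here)
}
for name, Q in variants.items():
    R = Q.T @ Q
    off_diag = R - np.diag(np.diag(R))
    print(f"{name}: off-diagonal terms in R? {bool(np.any(np.abs(off_diag) > 1e-12))}")

# Localized versions: one Q_k (hence one R_k) per prototype.
prototypes = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
local_Qs = [rng.normal(size=(p, p)) for _ in prototypes]
x = np.array([0.4, 0.6, 0.5])
local_d = [distance(x, w, Qk) for w, Qk in zip(prototypes, local_Qs)]
print("localized distances:", np.round(local_d, 3))
```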

To see more clearly how the relevance matrix affects the distance calculations, consider an extremely basic example with two-dimensional instances:

\({\bf x}=\begin{bmatrix} x_1 & x_2 \end{bmatrix}\)

\({\bf w}_k=\begin{bmatrix} w_1 & w_2 \end{bmatrix}\)

\({\bf R}=\begin{bmatrix} r_{11} & r_{12} \\ r_{12} & r_{22} \end{bmatrix}\)

Notice how we only have \(r_{12}\) and not \(r_{21}\). That is because, as mentioned, the matrix is symmetric, so \(r_{12}= r_{21}\). The distance between the instance and the prototype would be:

\[ \begin{aligned} d({\bf x},{\bf w}_k,{\bf R})&=({\bf x}-{\bf w}_k) {\bf R} ({\bf x}-{\bf w}_k)^\intercal \\ &=\begin{bmatrix} \delta_1 & \delta_2 \end{bmatrix} \begin{bmatrix} r_{11} & r_{12} \\ r_{12} & r_{22} \end{bmatrix} \begin{bmatrix} \delta_1 \\ \delta_2 \end{bmatrix} \\ &= \begin{bmatrix} (\delta_1 r_{11} + \delta_2 r_{12}) & (\delta_1 r_{12} + \delta_2 r_{22}) \end{bmatrix} \begin{bmatrix} \delta_1 \\ \delta_2 \end{bmatrix} \\ &=\delta_1^2 r_{11} + \delta_1 \delta_2 r_{12} + \delta_1 \delta_2 r_{12} + \delta_2^2 r_{22} \\ &=\delta_1^2 \cdot r_{11} + \delta_2^2 \cdot r_{22} + 2\,\delta_1 \cdot \delta_2 \, r_{12} \end{aligned} \]

where \(\delta_i=x_i-w_i\).
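The expansion above can be checked numerically. This small NumPy sketch uses hypothetical values for \({\bf x}\), \({\bf w}_k\) and \({\bf R}\). It compares the matrix form with the expanded sum; the cross term must carry the factor 2 because \(r_{12}\) appears twice in the symmetric matrix.

```python
import numpy as np

# Hypothetical numbers for the two-feature example
x = np.array([1.0, 2.0])
w_k = np.array([0.5, 1.0])
r11, r12, r22 = 0.6, 0.1, 0.4                  # tr(R) = r11 + r22 = 1
R = np.array([[r11, r12],
              [r12, r22]])                     # symmetric: r21 = r12

delta = x - w_k                                # [delta_1, delta_2] = [0.5, 1.0]
matrix_form = float(delta @ R @ delta)
expanded = delta[0]**2 * r11 + delta[1]**2 * r22 + 2 * delta[0] * delta[1] * r12
print(matrix_form, expanded)                   # both equal 0.65 up to rounding
```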

Now it’s easier to see how the choice of the model/relevance matrix impacts distance:

\(\text{GLVQ} \; (r_{12}=0 \; ; \; r_{ii}=1): d({\bf x},{\bf w}_k,{\bf R})=\delta_1^2 + \delta_2^2\)

\(\text{RLVQ}\; (r_{12}=0): \; d({\bf x},{\bf w}_k,{\bf R})=\delta_1^2 \cdot r_{11} + \delta_2^2 \cdot r_{22}\)

\(\text{GMLVQ} \; : d({\bf x},{\bf w}_k,{\bf R})=\delta_1^2 \cdot r_{11} + \delta_2^2 \cdot r_{22} + 2\,\delta_1 \cdot \delta_2 \, r_{12}\)

\(\text{LRLVQ} \; : d({\bf x},{\bf w}_k,{\bf R}_k)=\delta_1^2 \cdot r_{11}^k + \delta_2^2 \cdot r_{22}^k\)

\(\text{LGMLVQ} \; : d({\bf x},{\bf w}_k,{\bf R}_k)=\delta_1^2 \cdot r_{11}^k + \delta_2^2 \cdot r_{22}^k + 2\,\delta_1 \cdot \delta_2 \, r_{12}^k\)

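Notice how, in the localized versions, the relevance weights depend on the prototype used in the comparison. That’s what makes these variants so powerful. Because each prototype measures distance with its own metric, the decision boundaries between prototypes become quadratic instead of linear. This allows good classification of categories with nonlinear boundaries.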

How does the model learn?

One of the things that attracted me to these LVQ models is that their parameters can be optimized using gradient descent. This is especially attractive for me because it is the optimization method I understand best. The loss function is the relative distance loss, which was first used to develop the GLVQ model. For a single data point, it is defined as:

\(\mathcal{L}(d^+,d^-)=\phi\Bigl( \frac{d^+-d^-}{d^++d^-} \Bigr)\)

where \(\phi\) is an optional monotonic transformation. Here we assume, as in the paper I’m using as reference, that \(\phi(x)=x\); in the second part we’ll try other functions such as the sigmoid or ReLU. Notice that we have \(d^+\) and \(d^-\). These are, respectively, the distance from the data point to the closest prototype with the correct label (\({\bf w}^+\)) and the distance to the closest prototype with an incorrect label (\({\bf w}^-\)). From the energy-based learning point of view, the latter corresponds to the “most offending answer”, a name I find quite funny. The ratio always lies in \([-1, 1]\) and is negative exactly when \(d^+ < d^-\), that is, when the point is classified correctly. In the “localized” model scenario, we would also have \({\bf R}^+\) and \({\bf R}^-\).
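This sketch computes the cost for one sample in NumPy (an assumption; the post uses PyTorch). It finds \(d^+\) and \(d^-\) by masking prototypes by label and then applies \(\phi\), which defaults to the identity. The function name and the toy prototypes are hypothetical.

```python
import numpy as np

def relative_distance_loss(x, y, prototypes, proto_labels, Q, phi=lambda m: m):
    """GLVQ-style cost for one sample; phi defaults to the identity."""
    proj = (x - prototypes) @ Q.T                   # (K, p)
    dists = np.sum(proj ** 2, axis=1)               # squared distance to every prototype
    correct = proto_labels == y
    d_plus = dists[correct].min()                   # closest prototype with the right label
    d_minus = dists[~correct].min()                 # closest prototype with a wrong label
    mu = (d_plus - d_minus) / (d_plus + d_minus)    # lies in [-1, 1]
    return phi(mu), d_plus, d_minus

prototypes = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 3.0]])
proto_labels = np.array([0, 1, 1])
Q = np.eye(2) / np.sqrt(2)                          # R = I / 2
x, y = np.array([0.5, 0.0]), 0
loss, d_plus, d_minus = relative_distance_loss(x, y, prototypes, proto_labels, Q)
print(loss, d_plus, d_minus)
```

Here \(d^+ = 0.125\) and \(d^- = 1.125\), so the cost is \(-0.8\). The negative sign says the point is correctly classified, and minimizing the cost pushes it further towards \(-1\).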

With stochastic gradient descent (in the second part we’ll work with batch gradient descent), every parameter takes a small step against its gradient after each data point. For \({\bf w}^+\):

\({\bf w}^+_{t+1}={\bf w}^+_t-\eta \cdot \frac{\partial \mathcal{L}}{\partial {\bf w}^+_t}\)

How do we get the derivatives for every parameter? It is convenient to first decompose the computation of the loss (the “forward pass”) into steps:

\(\mathcal{L}(\mu)=\mu\)

\(\mu(d^+,d^-)=\frac{d^+-d^-}{d^++d^-}\)

\(d({\bf x},{\bf w},{\bf Q})=({\bf x}-{\bf w}) \;{\bf Q}^\intercal {\bf Q} \; ({\bf x}-{\bf w})^\intercal=\lVert ({\bf x}-{\bf w})\;{\bf Q}^\intercal \rVert_2^2\) (\({\bf R}\) is not learned directly; we learn \({\bf Q}\) instead)

Then the derivatives of each step are:

\(\frac{\partial \mathcal{L}}{\partial \mu}=1\)

\(\frac{\partial \mu}{\partial d^+}=\frac{(d^++d^-)-(d^+-d^-)}{(d^++d^-)^2}=\frac{2d^-}{(d^++d^-)^2}\)

\(\frac{\partial \mu}{\partial d^-}=\frac{-(d^++d^-)-(d^+-d^-)}{(d^++d^-)^2}=\frac{-2d^+}{(d^++d^-)^2}\)

\(\frac{\partial d}{\partial {\bf w}}=-2{\bf Q}^\intercal {\bf Q} ({\bf x}-{\bf w})^\intercal\)

\(\frac{\partial d}{\partial {\bf Q}}=2{\bf Q} ({\bf x}-{\bf w})^\intercal ({\bf x}-{\bf w})\)
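A finite-difference check is a cheap way to confirm derivatives like these before coding the updates. This NumPy sketch uses random inputs and hypothetical helper names. It compares each analytic derivative with central differences. The prototype gradient is computed as a row vector, the transpose of \(-2{\bf Q}^\intercal{\bf Q}({\bf x}-{\bf w})^\intercal\), so that it has the same shape as \({\bf w}\).

```python
import numpy as np

def mu(d_plus, d_minus):
    """Relative distance cost (phi = identity)."""
    return (d_plus - d_minus) / (d_plus + d_minus)

def dist(x, w, Q):
    """d(x, w, Q) = ||(x - w) Q^T||_2^2."""
    proj = (x - w) @ Q.T
    return float(proj @ proj)

def numerical_grad(f, A, eps=1e-6):
    """Central finite differences of a scalar function f with respect to array A."""
    grad = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        A_hi, A_lo = A.copy(), A.copy()
        A_hi[idx] += eps
        A_lo[idx] -= eps
        grad[idx] = (f(A_hi) - f(A_lo)) / (2 * eps)
    return grad

rng = np.random.default_rng(2)
p = 3
x, w = rng.normal(size=p), rng.normal(size=p)
Q = rng.normal(size=(p, p))
dp, dm = 0.7, 1.9                        # arbitrary positive distances

# Analytic derivatives of each step
dmu_ddp = 2 * dm / (dp + dm) ** 2
dmu_ddm = -2 * dp / (dp + dm) ** 2
dd_dw = -2 * (x - w) @ Q.T @ Q           # row-vector form of -2 Q^T Q (x - w)^T
dd_dQ = 2 * Q @ np.outer(x - w, x - w)   # 2 Q (x - w)^T (x - w)

# Numerical derivatives
eps = 1e-6
num_ddp = (mu(dp + eps, dm) - mu(dp - eps, dm)) / (2 * eps)
num_ddm = (mu(dp, dm + eps) - mu(dp, dm - eps)) / (2 * eps)
num_dw = numerical_grad(lambda w_: dist(x, w_, Q), w)
num_dQ = numerical_grad(lambda Q_: dist(x, w, Q_), Q)

print(np.isclose(num_ddp, dmu_ddp), np.isclose(num_ddm, dmu_ddm))
print(np.allclose(num_dw, dd_dw, atol=1e-5), np.allclose(num_dQ, dd_dQ, atol=1e-5))
```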

And by chain rule:

\(\frac{\partial \mathcal{L}}{\partial {\bf w}^+}=1 \cdot \frac{2d^-}{(d^++d^-)^2} \cdot \Bigl(-2{\bf Q}^\intercal {\bf Q} ({\bf x}-{\bf w}^+)^\intercal\Bigr)\)

\(\frac{\partial \mathcal{L}}{\partial {\bf w}^-}=1 \cdot \frac{-2d^+}{(d^++d^-)^2} \cdot \Bigl(-2{\bf Q}^\intercal {\bf Q} ({\bf x}-{\bf w}^-)^\intercal\Bigr)\)

\(\frac{\partial \mathcal{L}}{\partial {\bf Q}^+}=1 \cdot \frac{2d^-}{(d^++d^-)^2} \cdot 2{\bf Q} \; ({\bf x}-{\bf w}^+)^\intercal ({\bf x}-{\bf w}^+)\)

\(\frac{\partial \mathcal{L}}{\partial {\bf Q}^-}=1 \cdot \frac{-2d^+}{(d^++d^-)^2} \cdot 2{\bf Q} \; ({\bf x}-{\bf w}^-)^\intercal ({\bf x}-{\bf w}^-)\)

In GMLVQ, where a single relevance matrix is shared by all prototypes, the two contributions are added:

\(\frac{\partial \mathcal{L}}{\partial {\bf Q}}=\frac{\partial \mathcal{L}}{\partial {\bf Q}^+} + \frac{\partial \mathcal{L}}{\partial {\bf Q}^-}\)

Now we have everything we need to implement gradient descent. Tweaks such as a decaying learning rate are optional, but I included one anyway to be consistent with the paper.
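The exact decay schedule is not specified here. As a purely hypothetical example, an inverse-time decay shrinks each learning rate as training progresses. The names and values below are illustrative assumptions, not the schedule from the paper.

```python
def decayed_lr(lr0, epoch, decay=0.01):
    """Hypothetical schedule: the rate shrinks as lr0 / (1 + decay * epoch)."""
    return lr0 / (1.0 + decay * epoch)

lr_w0, lr_q0 = 0.01, 0.001   # hypothetical initial rates for prototypes and for Q
for epoch in (0, 10, 100, 1000):
    print(epoch, decayed_lr(lr_w0, epoch), decayed_lr(lr_q0, epoch))
```

In a training loop, the rates computed for the current epoch would simply be passed to the update step in place of fixed values.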

Other architectural details

How do we initialize the parameters? That’s a key question, and in my case I followed the suggestions of PAPER. For the prototypes, each vector associated with class \(k\) is the average of a random sample of \(z\) observations from class \(k\). Here \(z = N_k / m_k\), where \(N_k\) is the number of observations in class \(k\) and \(m_k\) is the number of prototypes for class \(k\). In the simple case of one prototype per class, \(z = N_k\), so the prototypes are initialized as the class-conditional means.
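The following NumPy sketch shows one way to implement this initialization. Several details are assumptions not fixed by the description: every class gets the same number of prototypes, \(z\) is rounded down, and each sample is drawn without replacement (independently for each prototype). `init_prototypes` and the toy data are hypothetical.

```python
import numpy as np

def init_prototypes(X, y, protos_per_class, rng):
    """Each prototype of class k is the mean of z = N_k // m_k random class-k points."""
    prototypes, proto_labels = [], []
    for k in np.unique(y):
        X_k = X[y == k]                              # all observations of class k
        z = len(X_k) // protos_per_class             # sample size per prototype
        for _ in range(protos_per_class):
            idx = rng.choice(len(X_k), size=z, replace=False)
            prototypes.append(X_k[idx].mean(axis=0))
            proto_labels.append(k)
    return np.array(prototypes), np.array(proto_labels)

rng = np.random.default_rng(3)
X = rng.normal(size=(60, 2))
y = np.repeat([0, 1, 2], 20)
X[y == 1] += 3.0                                     # shift classes apart (toy data)
X[y == 2] -= 3.0
W, W_labels = init_prototypes(X, y, protos_per_class=1, rng=rng)
print(W, W_labels)
```

With one prototype per class, the sample is the whole class, so each prototype equals the class mean.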

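As for \({\bf Q}\), it is initialized as a diagonal matrix with every diagonal element set to \(\sqrt{\frac{1}{p}}\). This gives \({\bf R}={\bf Q}^\intercal{\bf Q}=\frac{1}{p}{\bf I}\), which satisfies \(tr({\bf R})=1\). It also means that, at the start of training, all features are equally relevant and the distance is a scaled squared Euclidean distance.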
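A quick NumPy check of this initialization (NumPy is an assumption; the post uses PyTorch) confirms that the implied relevance matrix is \(\frac{1}{p}{\bf I}\) with unit trace. The value of `p` is an arbitrary example.

```python
import numpy as np

p = 5                                   # number of features (example value)
Q0 = np.eye(p) * np.sqrt(1.0 / p)       # diagonal, every diagonal entry sqrt(1/p)
R0 = Q0.T @ Q0                          # equals I / p
print(np.diag(R0), np.trace(R0))
```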