Deep representation learning using layer-wise VICReg losses

The layer-wise trained model is evaluated in two ways. First, a linear classifier is trained on the representations of the pre-trained model, with all pre-trained layers frozen. Several subsets of the processed datasets are used to evaluate the efficiency of the trained model. Second, the classification accuracy of a baseline trained from scratch is compared with that of a fine-tuned model (all layers of the pre-trained model are fine-tuned) trained on the same subsets. The findings suggest that the linear classifier works well, meaning that the pre-trained model has learned important features. Furthermore, the training procedure yields accuracy gains of approximately 7%, 16%, 1%, and 7% over the baselines when the models are fine-tuned on small subsets of MNIST, EMNIST, Fashion MNIST, and CIFAR-100, respectively. Although more labeled samples reduce the need for the proposed training procedure, it remains useful when only a small labeled dataset but a large number of unlabeled samples are available, since it still improves performance with limited labels.

Experiment on MNIST

We use clustering quality metrics to evaluate the feature informativeness of the pre-trained model. Clustering quality is assessed from the representations of the pre-trained model, and classification performance is evaluated using different-sized subsets of the labeled data.

Learned representation quality assessment with cluster quality metrics: After training the model layerwise on MNIST for 50 epochs with weights (25, 25, 1) for the loss terms variance, invariance, and covariance respectively, we obtain the cluster quality metrics shown in Table 4. The DB and CH indices assess clustering quality across layers, with each metric offering a different perspective on clustering performance. The DB index measures how well the clusters are separated from each other, with lower DB values indicating better clusters. This is achieved by evaluating the ratio of the intra-cluster dispersion to the inter-cluster separation. The gradual reduction in the DB index across layers suggests that the deeper layers form more compact clusters with lower intra-cluster variance, resulting in refined and informative representations.
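To make the role of the three loss weights concrete, the following is a minimal NumPy sketch of a weighted VICReg loss for one layer's embeddings of two augmented views. It follows the standard VICReg formulation; the function name `vicreg_loss`, the hinge target `gamma = 1`, the stabiliser `eps`, and the use of NumPy are assumptions for illustration and are not taken from this section.

```python
import numpy as np

def vicreg_loss(z_a, z_b, w_var=25.0, w_inv=25.0, w_cov=1.0, gamma=1.0, eps=1e-4):
    """Weighted VICReg loss for two (n, d) embedding batches of the same inputs.

    Default weights mirror the (25, 25, 1) setting quoted in the text for
    variance, invariance, and covariance. gamma and eps are assumed values.
    """
    n, d = z_a.shape

    # Invariance: mean squared distance between the two views.
    inv = np.mean(np.sum((z_a - z_b) ** 2, axis=1))

    def variance_term(z):
        # Hinge on the per-dimension standard deviation to discourage collapse.
        std = np.sqrt(z.var(axis=0, ddof=1) + eps)
        return np.mean(np.maximum(0.0, gamma - std))

    def covariance_term(z):
        # Penalise squared off-diagonal entries of the covariance matrix.
        zc = z - z.mean(axis=0)
        cov = zc.T @ zc / (n - 1)
        off_diag = cov - np.diag(np.diag(cov))
        return np.sum(off_diag ** 2) / d

    var = variance_term(z_a) + variance_term(z_b)
    cov = covariance_term(z_a) + covariance_term(z_b)
    total = w_var * var + w_inv * inv + w_cov * cov
    return total, (var, inv, cov)
```

In a layer-wise setting, a loss of this form would be evaluated on the output of each layer separately, so that every layer receives its own local training signal.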

Table 4 Cluster quality metrics for each layer of the model trained on MNIST.

The CH index measures the ratio of the sum of between-cluster dispersion to within-cluster dispersion. Higher CH values indicate better clustering, with well-separated and compact clusters. As shown in Table 4, the CH score peaks at the third layer, indicating improved representation learning during training. However, the slight dip at the last layer suggests that deeper layers may sacrifice some separation for more abstract representations, as expected in hierarchical models.
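The two indices can be computed directly from a layer's representations and a set of cluster labels. The sketch below is a plain NumPy version of the standard Davies-Bouldin and Calinski-Harabasz definitions. The function name `cluster_quality` is hypothetical, and how the cluster labels are obtained (for example ground-truth classes or k-means assignments) is an assumption, since this section does not state it.

```python
import numpy as np

def cluster_quality(x, labels):
    """Return (DB index, CH index) for points x of shape (n, d) and integer labels."""
    x = np.asarray(x, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    n, k = len(x), len(classes)
    centroids = np.stack([x[labels == c].mean(axis=0) for c in classes])
    sizes = np.array([(labels == c).sum() for c in classes])

    # Davies-Bouldin: average over clusters of the worst-case ratio
    # (scatter_i + scatter_j) / distance(centroid_i, centroid_j). Lower is better.
    scatter = np.array([
        np.linalg.norm(x[labels == c] - centroids[i], axis=1).mean()
        for i, c in enumerate(classes)
    ])
    dist = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=2)
    np.fill_diagonal(dist, np.inf)  # exclude a cluster's comparison with itself
    db = ((scatter[:, None] + scatter[None, :]) / dist).max(axis=1).mean()

    # Calinski-Harabasz: between-cluster dispersion over within-cluster
    # dispersion, each normalised by its degrees of freedom. Higher is better.
    between = np.sum(sizes * np.sum((centroids - x.mean(axis=0)) ** 2, axis=1))
    within = sum(
        np.sum((x[labels == c] - centroids[i]) ** 2) for i, c in enumerate(classes)
    )
    ch = (between / (k - 1)) / (within / (n - k))
    return db, ch
```

For two clusters `{(0, 0), (0, 2)}` and `{(10, 0), (10, 2)}`, each point lies at distance 1 from its centroid and the centroids are 10 apart, so the DB index is (1 + 1) / 10 = 0.2 and the CH index is (100 / 1) / (4 / 2) = 50.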

Figure 4a shows the epochs versus the DB index values in each layer as training progresses for 50 epochs. The DB index of each layer converges at a point near epoch 50, indicating deeper layers are now able to retain similar information as lower layers despite having fewer neurons. Figure 4b shows the loss minimization per layer during training on MNIST using the proposed approach. Note that although lower layers have lower DB index values, they exhibit a slightly higher loss due to their high-dimensional output embeddings.

Fig. 4

Performance of the model trained on MNIST as training continues.

Testing influence of learned representations in downstream task: The pre-trained model is then fine-tuned with fine-tune data subsets of different sizes, ranging from 500 samples to 5000 samples. A baseline model with the same architecture is also trained from scratch on similarly sized fine-tune datasets of 500, 1000, 2500, and 5000 samples. Training is performed with an initial learning rate of 0.0005 and a batch size of 128.
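One way to draw labeled subsets of a fixed size is to sample an equal number of indices from every class. The sketch below illustrates this under the assumption that the subsets are class-balanced, which the text does not state explicitly; `stratified_subset` is a hypothetical helper name.

```python
import random
from collections import defaultdict

def stratified_subset(labels, subset_size, seed=0):
    """Return shuffled indices of a class-balanced subset of the given size."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    classes = sorted(by_class)
    per_class, remainder = divmod(subset_size, len(classes))
    chosen = []
    for i, c in enumerate(classes):
        # Spread any remainder over the first few classes.
        quota = per_class + (1 if i < remainder else 0)
        # random.sample raises ValueError if a class has fewer items than its quota.
        chosen.extend(rng.sample(by_class[c], quota))
    rng.shuffle(chosen)
    return chosen
```

With ten classes, a 500-sample subset drawn this way contains 50 samples per class.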

Fig. 5

Training accuracy curve of the baseline model trained on 500 MNIST samples.

Fig. 6

Training accuracy curve of the linear classifier trained on 500 MNIST samples.

Fig. 7

Training accuracy curve of the fine-tuned model using 500 MNIST samples.

An exponential decay learning rate scheduler with a decay step of 50 and a decay rate of 0.95 is used. Instead of training for a fixed number of epochs, early stopping with a patience of five is used for both fine-tuning and the baseline, allowing each model to stop near its optimal point. For the linear evaluation, we set the initial learning rate to 0.005 and the patience to 10, since only the softmax layer is trained and the model may underfit if too few epochs are used. The training curves in Figures 5, 6, and 7 suggest that none of the models overfits the dataset. The gap between the validation and training curves is large, but this is because only a small amount of labeled data is available for the model to learn from.
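The schedule and stopping rule described above can be sketched in a few lines of plain Python. The decay formula follows the common convention `lr = initial_lr * decay_rate ** (step / decay_steps)`. Whether the decay is continuous or stepwise, and which quantity early stopping monitors, are not stated in the text, so the `staircase` flag and the use of validation loss are assumptions, as are the names `exponential_decay` and `EarlyStopping`.

```python
def exponential_decay(initial_lr, step, decay_steps=50, decay_rate=0.95, staircase=False):
    """Learning rate after `step` optimizer steps under exponential decay."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


class EarlyStopping:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0  # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

With an initial learning rate of 0.0005, the rate after 50 steps is 0.0005 × 0.95 = 0.000475. A patience of 10, as used for the linear evaluation, simply tolerates a longer run of epochs without improvement.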

Table 5 shows the macro average precision, recall, and F1 score of MNIST classification on each of the subsets with different numbers of samples, rounded to the nearest whole percent. The macro average of the metrics is used because the dataset is well-balanced. The uniformity in the metrics is evident from the table. This indicates that the trained models are consistent in predicting instances from each class, suggesting that the models are balanced and reliable.
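For reference, macro averaging computes each metric per class and then takes the unweighted mean over classes. The sketch below shows this in plain Python; `macro_precision_recall_f1` is a hypothetical helper, and treating an undefined ratio as zero is an assumed convention.

```python
def macro_precision_recall_f1(y_true, y_pred):
    """Unweighted mean of per-class precision, recall, and F1."""
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    k = len(classes)
    return sum(precisions) / k, sum(recalls) / k, sum(f1s) / k
```

Because every class contributes equally, the three macro scores stay close to one another only when no class is predicted markedly worse than the rest, which is the uniformity the table is read for.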

Table 5 Precision, recall and F1 score comparison of MNIST classification.

Table 6 presents the performance comparison of various model versions trained on subsets of different sizes from the MNIST dataset. With just 500 samples, fine-tuning achieves an accuracy of 85.13%, while the baseline reaches only 77.78%. Similarly, for 1000 samples, fine-tuning results in 88.66% accuracy, whereas the baseline lags at 81.33%. When there is limited data available (around 500 or 1000 samples), fine-tuning delivers impressive results. It achieves up to 7% higher accuracy than the baseline model. As we increase the sample size to 2500 and 5000, the advantage of fine-tuning becomes less pronounced, but it still maintains a slight edge in terms of accuracy and efficiency. For 2500 samples, fine-tuning achieves an accuracy of 91.69%, whereas the baseline manages 88.52%, and with 5000 samples, fine-tuning reaches an accuracy of 94.23%, compared to the 93.68% accuracy of the baseline. This demonstrates that the proposed approach excels in a low-data regime, offering better generalization and efficiency. Traditional training methods become more competitive as data availability grows. Figure 8 shows the confusion matrices of our baseline and fine-tuned models on an MNIST subset of size 500.
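The shrinking advantage can be read off directly from the accuracies quoted above. The short snippet below computes the gain in percentage points for each subset size; the dictionary only restates the figures from the text, and the helper name `accuracy_gains` is illustrative.

```python
# subset size -> (baseline accuracy %, fine-tuned accuracy %), as quoted in the text
mnist_accuracy = {
    500: (77.78, 85.13),
    1000: (81.33, 88.66),
    2500: (88.52, 91.69),
    5000: (93.68, 94.23),
}

def accuracy_gains(results):
    """Fine-tuned minus baseline accuracy, in percentage points."""
    return {n: round(ft - base, 2) for n, (base, ft) in results.items()}

gains = accuracy_gains(mnist_accuracy)
for n, gain in gains.items():
    print(f"{n:>5} samples: +{gain:.2f} percentage points")
```

The gain falls from about 7.35 points at 500 samples to 0.55 points at 5000 samples.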

Table 6 Accuracy comparison of MNIST classification.
Fig. 8

Confusion matrices of baseline (a) and fine-tuned (b) models trained on 500 MNIST samples.

Figure 9 compares the baseline model and the fine-tuned model, each trained on 500 samples for 50 epochs without early stopping or any other automatic convergence criterion. Learning of the baseline model plateaus at an early epoch due to insufficient data, halting its ability to improve. In contrast, the model pre-trained with the proposed approach and then fine-tuned on the same small labeled subset outperforms the baseline by around 7%. As illustrated in the plot, the accuracy of the baseline model plateaus around 75%, whereas the fine-tuned model continues to improve, ultimately reaching above 85% training accuracy, meaning that the representations initially learned via the proposed approach are useful.

Fig. 9

Accuracy comparison between baseline model and fine-tuned model on MNIST with 500 samples.

Experiment on EMNIST

The pre-trained EMNIST model is evaluated by the same approach as discussed earlier for the MNIST experiment. Evaluation reports and discussions are given below for the EMNIST experiment.

Learned representation quality assessment with cluster quality metrics: The same loss term weights (25 for variance, 25 for invariance, and 1 for covariance) are applied while training the model on the EMNIST dataset. Table 7 presents the cluster quality metrics for each layer. The DB index and CH score are used to assess the quality of the clusters. The decreasing trend of the DB index across layers indicates that the deeper layers of the model successfully capture meaningful features, ensuring more compact and well-separated clusters. Similarly, the CH score stabilises in the deeper layers, reflecting the model’s capability to manage the higher complexity of EMNIST. These observations indicate that our layer-wise VICReg training remains effective even when applied to larger, multi-class datasets like EMNIST. Figure 10a shows epochs versus DB index values, while Figure 10b shows epochs versus loss minimization of each layer while training on EMNIST.

Table 7 Cluster quality metrics for each layer of the model trained on EMNIST.
Fig. 10

Performance of the model trained on EMNIST as training continues.

Testing influence of learned representations in classification task: A pre-trained model (trained on unlabeled EMNIST samples) is fine-tuned, and a baseline model with the same architecture is trained, on five fine-tune subsets containing 1000, 1500, 2000, 5000, and 9400 samples to compare how well the proposed approach works. The initial learning rate is set to 0.0005 with a batch size of 256. A larger batch size is used because EMNIST consists of many classes, and each batch should ideally contain enough samples from each class. As in the MNIST experiment, exponential learning rate decay and early stopping (with the same configuration) are used. Figures 11, 12, and 13 show the training accuracy curves of the baseline, linear classifier, and fine-tuned model. Table 8 presents the rounded macro average precision, recall, and F1 score of EMNIST classification on each subset. The closeness of the three metrics indicates consistent performance across all classes.

Table 9 presents the accuracy and number of epochs for the models trained on subsets of the EMNIST dataset. When labeled data is very limited (e.g., 1000 or 1500 samples), fine-tuning the pre-trained model yields an accuracy gain of around 16% over baseline models trained on the same subsets. Specifically, with only 1000 labeled samples, the baseline model reaches 40.52% accuracy, whereas our fine-tuned model reaches 56.71% accuracy.

Fig. 11

Training accuracy curve of the baseline model trained on 1000 EMNIST samples.

Fig. 12

Training accuracy curve of the linear classifier trained on 1000 EMNIST samples.

Fig. 13

Training accuracy curve of the fine-tuned model trained on 1000 EMNIST samples.

Table 8 Precision, recall, and F1 score comparison of EMNIST classification.
Table 9 Accuracy comparison of EMNIST classification.

For 1500 samples, the baseline model achieves 42.26% accuracy, while the fine-tuned model reaches 59.03%. As discussed earlier, with sufficient labeled data the baseline reaches performance comparable to the proposed method. However, due to the efficiency of the pre-trained model, the fine-tuned model rapidly achieves a high accuracy within just a few epochs. Figure 14 shows the confusion matrices of our baseline and fine-tuned model outputs. The predictions of the baseline are scattered across classes, whereas in the fine-tuned model the diagonal elements are prominent and the off-diagonal entries are less scattered, meaning that its predictions are more accurate than those of the baseline.

Fig. 14

Confusion matrices of baseline (a) and fine-tuned (b) models trained on 1000 EMNIST samples.

For comparison, a baseline model is trained with 1000 samples for 50 epochs without early stopping or any other automatic convergence criterion. The pre-trained EMNIST model is then fine-tuned with the same number of samples for 50 epochs, and the two are compared in Fig. 15. Learning of the baseline model plateaus early due to insufficient data, halting its ability to improve. In contrast, the fine-tuned model outperforms the baseline by a significant margin. As shown in the graph, the accuracy of the baseline model plateaus around 40%, while the fine-tuned model continues to improve, ultimately reaching training accuracy above 50%.

Fig. 15

Accuracy comparison between baseline and fine-tuned models on EMNIST with 1000 samples over 50 epochs.

Experiment on more complex datasets

The model is subsequently evaluated on more complex datasets, namely Fashion MNIST and CIFAR-100. The same weights for the variance, invariance, and covariance loss terms are used to train models on both datasets. Tables 10 and 11 show the cluster quality metrics after training the model for 50 epochs on Fashion MNIST and CIFAR-100, respectively. Despite the increased complexity of the datasets, the proposed approach is still able to capture useful features, as evident from the cluster quality metrics.

Table 10 Clustering quality metrics for each layer of the model trained on Fashion MNIST.
Table 11 Clustering quality metrics for each layer of the model trained on CIFAR-100.

The pre-trained models are then fine-tuned on the respective datasets (Fashion MNIST and CIFAR-100) and compared to baselines with identical architectures and trained on the same subsets. For Fashion MNIST, four subsets of size 500, 1000, 2500, and 5000 samples are used, while for CIFAR-100, three larger subsets (2500, 3500, and 5000 samples) are used due to the complexity and number of classes in the dataset. CIFAR-100 subsets are chosen to be larger because there are 100 classes, and there must be enough samples from each class in the subsets. A dropout layer with a dropout rate of 50% is added right before the softmax layer to minimize the risk of overfitting, given the small subset sizes and dataset complexity. A learning rate of 0.0001 is used for both models during fine-tuning with an exponential decay learning rate scheduler (decay step: 50 and a decay rate: 0.95). The model is fine-tuned for 200 epochs with a batch size of 128. The baseline is trained under the same configurations and data subsets to ensure a valid comparison.
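As an illustration of the regularisation step, the sketch below implements inverted dropout in NumPy with the 50% rate mentioned above. The framework actually used is not named in this section, so the function `dropout` and its `training` flag are assumptions modelled on how common deep learning libraries behave.

```python
import numpy as np

def dropout(x, rate=0.5, training=True, rng=None):
    """Inverted dropout: zero a fraction `rate` of activations during training
    and rescale the survivors so the expected activation is unchanged."""
    if not training or rate == 0.0:
        return x  # inference: the layer is the identity
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)
```

Placed right before the softmax layer, this randomly removes half of the final hidden features in each training step, which limits how strongly the classifier can rely on any single feature when only a few labeled samples are available.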

Table 12 compares the efficiency of the pre-trained model on Fashion MNIST with limited labeled data. From the table, it is evident that fine-tuning the pre-trained model with very few samples can lead to an accuracy gain of around 1%. Additionally, adding a softmax layer to the pre-trained model and training only that layer achieves 55% to 66% accuracy on the same small subsets, demonstrating that the representations learned during pre-training are effective.
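Training only a softmax layer on top of frozen representations amounts to multinomial logistic regression on fixed features. The following NumPy sketch shows the idea. Here `features` stands for the output of the frozen pre-trained layers, and the function name, the use of plain full-batch gradient descent, and the epoch count are assumptions rather than the configuration used in the experiments.

```python
import numpy as np

def train_linear_probe(features, labels, num_classes, lr=0.005, epochs=200):
    """Fit a softmax classifier on fixed features; the encoder is never updated."""
    n, d = features.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    losses = []
    for _ in range(epochs):
        logits = features @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        # Mean cross-entropy of the current classifier.
        losses.append(-np.mean(np.log(probs[np.arange(n), labels] + 1e-12)))
        # Gradient of the mean cross-entropy with respect to the logits.
        grad = (probs - onehot) / n
        w -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return w, b, losses
```

Since only `w` and `b` are learned, the accuracy of such a probe reflects how linearly separable the frozen representations already are.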

Table 12 Accuracy comparison of Fashion MNIST classification.

In the CIFAR-100 experiment, accuracies are low due to the complexity of the dataset and the use of a fully connected neural network composed only of dense layers. A dense-only network cannot fully capture the complex patterns in that dataset. Nevertheless, the approach still helps the model achieve better accuracy. For example, the model fine-tuned with just 500 samples (an average of five samples from each class) reaches 12.03% accuracy, whereas the baseline trained on the same subset reaches only 7.05%. With a simple architecture and a very small labeled dataset, the approach improves accuracy by approximately 5% to 8% in the CIFAR-100 experiment, as shown in Table 13.

Table 13 Accuracy comparison of CIFAR-100 classification.

Figures 16 and 17 compare the training accuracy of the baseline model and the pre-trained model on the Fashion MNIST and CIFAR-100 datasets. While the accuracy of the baseline model lags, the pre-trained model demonstrates a steady increase in accuracy throughout training.

Fig. 16

Accuracy comparison between baseline and fine-tuned models on Fashion MNIST with 1000 samples over 50 epochs.

Fig. 17

Accuracy comparison between baseline and fine-tuned models on CIFAR-100 with 1000 samples over 50 epochs.

Comparison with self-supervised learning and layer-wise training approaches

The proposed VICReg layer-wise method focuses on enhancing cluster and feature representation quality through a layer-wise optimization strategy. This approach ensures that features remain both compact and well-separated at each layer, as validated by the DB and CH metrics, making it highly effective for multi-class tasks such as MNIST and EMNIST. In comparison, the Forward-Forward (FF) layer-wise algorithm aims to eliminate backpropagation by using local optimization at each layer, making it suitable for resource-constrained hardware and neuromorphic computing. As shown in Table 14, the scalability and transfer performance of FF are limited compared to VICReg’s progressive representation refinement.

Similarly, Forward-Forward in a self-supervised learning (SSL) setting focuses on learning representations without labels through a goodness-based loss at each layer. While it aligns with VICReg’s layer-wise training concept, it underperforms in transfer learning tasks, especially on more complex datasets like CIFAR-10 and SVHN. The probabilistic SSL approach (SimVAE) prioritizes style-content retention in its generative framework, making it better suited for fine-grained tasks requiring nuanced features. However, VICReg’s structured clustering optimization provides a practical advantage for tasks needing clear feature separation. Finally, SimCLR and MoCo, with their contrastive learning frameworks, excel in transfer learning but require large datasets and complex augmentations, which can be resource-intensive.

In contrast, VICReg achieves robust performance through structured layer-wise optimization, offering a scalable and efficient solution for multi-class classification. This comparison, summarized in Table 14, highlights how VICReg bridges the gap between layer-wise optimization and effective representation learning, making it a strong competitor across various applications.

A related study explores MNIST and Fashion MNIST classification with various SSL methods and SimVAE [44]. Whereas these methods involve various advanced architectures, our method outperforms them or achieves comparable performance with a simple architecture and very little labeled data. Table 15 shows the comparison between various SSL methods and our proposed method.

Table 14 Comparison with related works.
Table 15 Classification accuracy comparison with various SSL methods.

Discussion on computational efficiency and scalability

Pre-training the model takes approximately three to six seconds per epoch, depending on the complexity and size of the dataset. It takes longer than fine-tuning because pre-training involves computing the VICReg loss and performing local updates at each layer. However, this approach eliminates the need for massive labeled datasets and for the use of backpropagation during pre-training. With very little labeled data, the pre-trained model readily achieves higher accuracy than a baseline trained on the same small dataset, owing to the proposed pre-training step. Table 16 shows the time taken (in seconds) for each scenario: linear, baseline, and fine-tune.

Table 16 Average time taken at each epoch (in seconds).

Although our current experiments use MLPs as a proof of concept for layer-wise training with local VICReg losses, the approach can potentially scale to larger architectures and datasets. VICReg is architecture-agnostic and has been successfully integrated with deeper architectures, including ResNet and Transformers [2, 23]. A recent study shows that large convolutional networks can be trained well with local losses like VICReg [47]. It demonstrates that deep networks like ResNet-50 can achieve nearly similar performance with local objectives, suggesting that our conceptually similar approach could perform effectively as well. In this work, we focus on MLPs to prove our concept, and we leave generalizability studies to future work.

Representation space evolution during training

To visualize the evolution of the representation space during training, we use t-distributed Stochastic Neighbor Embedding (t-SNE), a technique for projecting high-dimensional data into a low-dimensional space [48]. t-SNE converts the similarities between data points into probabilities, both in the original space and in the low-dimensional embedding, and then minimizes the divergence between the two distributions. As a result, similar objects are represented by nearby points and dissimilar objects by distant points with high probability. Figure 18 illustrates how the representation space evolves throughout training: the high-dimensional output of the model is projected onto a two-dimensional plane at epochs 0, 10, 20, 30, 40, and 50. The details are described below.

  • Epoch 0: At the start, the clusters of data overlap, which means the model has not yet learned how to separate different classes effectively.

  • Epoch 10: As training progresses, the clusters start to become more distinct, but there is still some overlap.

  • Epoch 20: The separation between clusters keeps improving, showing that the model is starting to learn meaningful representations.

  • Epoch 30: The clusters become even more defined, with less overlap between different classes.

  • Epoch 40: The representation space shows well-separated clusters, indicating that the model is learning effectively and doing a better job at separating classes.

  • Epoch 50: Finally, at the last epoch, each cluster is clearly defined with minimal overlap. This demonstrates that the model has successfully learned how to represent the data. At this point, the individual VICReg loss and DB index for each layer converge at similar points.
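For reference, the objective that t-SNE minimizes when producing such projections can be written as follows. This is the standard formulation of the method, not something specific to this work. Here $x_i$ are the high-dimensional representations, $y_i$ their two-dimensional images, $n$ the number of points, and $\sigma_i$ a per-point bandwidth set by the chosen perplexity.

```latex
% Similarities in the original representation space
p_{j \mid i} = \frac{\exp\!\left(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2\right)}
                    {\sum_{k \neq i} \exp\!\left(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2\right)},
\qquad
p_{ij} = \frac{p_{j \mid i} + p_{i \mid j}}{2n}

% Similarities in the two-dimensional embedding (Student-t kernel)
q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}
              {\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}}

% Objective: Kullback-Leibler divergence between the two distributions
C = \mathrm{KL}(P \,\Vert\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
```

Because the embedding depends on the perplexity and on random initialisation, distances between well-separated clusters in such plots are best read qualitatively.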

Fig. 18

Evolution of the final representation space during training of the MNIST model, shown at different epochs: a Epoch 0, b Epoch 10, c Epoch 20, d Epoch 30, e Epoch 40, and f Epoch 50.
