Add logic to save the result graphs

This commit is contained in:
Peter Vacho 2024-12-09 13:40:11 +01:00
parent 8b8d882f1c
commit 0e0cc48619
Signed by: school
GPG key ID: 8CFC3837052871B4

View file

@ -2,6 +2,7 @@ import os
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from time import perf_counter
from typing import cast, final
@ -20,6 +21,8 @@ from sklearn.utils import Bunch
# Set to -1 to disable downsizing
DOWNSIZE_MNIST: int = 8000
OUTPUT_DIR = Path.cwd() / "output"
# pyqt6 only bundles Windows & Fusion styles, which means that if you use a
# different preferred qt style, a warning would be produced. This gets rid
# of that warning and removes the env-var override.
@ -111,11 +114,22 @@ def knn_accuracy(data: MLDataset, k_range: range | None = None) -> tuple[object,
return best_k, acc
def plot_2d(x: DataFrame, y: "Series[str]", title: str) -> None:
def plot_2d(
x: DataFrame,
y: "Series[str]",
title: str,
*,
show_plot: bool = False,
save_plot: Path | None = None,
) -> None:
"""Show a 2D visualization of the given 2D (dim-reduced) dataset."""
plt.figure(figsize=(8, 6))
sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15)
plt.title(title)
if save_plot:
plt.savefig(save_plot)
if show_plot:
plt.show()
@ -132,6 +146,8 @@ def timed(start_msg: str, end_msg: str) -> Iterator[None]:
def main() -> None:
"""Program entrypoint."""
OUTPUT_DIR.mkdir(exist_ok=True)
with timed("Loading the MNIST dataset", "MNIST dataset loaded"):
# Working with the entire dataset would be way too computationally expensive
# (TSNE would take hours, if not more), instead, downsize the dataset and work
@ -165,8 +181,20 @@ def main() -> None:
print()
with timed("Showing graphs", "Finished"):
plot_2d(mnist_pca.x_train, mnist_pca.y_train, "2D PCA of MNIST")
plot_2d(mnist_tsne.x_train, mnist_tsne.y_train, "2D t-SNE of MNIST")
plot_2d(
mnist_pca.x_train,
mnist_pca.y_train,
"2D PCA of MNIST",
show_plot=True,
save_plot=OUTPUT_DIR.joinpath("pca.png"),
)
plot_2d(
mnist_tsne.x_train,
mnist_tsne.y_train,
"2D t-SNE of MNIST",
show_plot=True,
save_plot=OUTPUT_DIR.joinpath("tsne.png"),
)
if __name__ == "__main__":