From 0e0cc486191571322570bc3b3d6b415463cf2fc4 Mon Sep 17 00:00:00 2001 From: Peter Vacho Date: Mon, 9 Dec 2024 13:40:11 +0100 Subject: [PATCH] Add logic to save the result graphs --- src/__main__.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/__main__.py b/src/__main__.py index 5337c71..ff9504e 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -2,6 +2,7 @@ import os from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass +from pathlib import Path from time import perf_counter from typing import cast, final @@ -20,6 +21,8 @@ from sklearn.utils import Bunch # Set to -1 to disable downsizing DOWNSIZE_MNIST: int = 8000 +OUTPUT_DIR = Path.cwd() / "output" + # pyqt6 only bundles Windows & Fusion styles, which means that if you use a # different preferred qt style, a warning would be produced. This gets rid # of that warning and removes the env-var override. @@ -111,12 +114,23 @@ def knn_accuracy(data: MLDataset, k_range: range | None = None) -> tuple[object, return best_k, acc -def plot_2d(x: DataFrame, y: "Series[str]", title: str) -> None: +def plot_2d( + x: DataFrame, + y: "Series[str]", + title: str, + *, + show_plot: bool = False, + save_plot: Path | None = None, +) -> None: """Show a 2D visualization of the given 2D (dim-reduced) dataset.""" plt.figure(figsize=(8, 6)) sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15) plt.title(title) - plt.show() + + if save_plot: + plt.savefig(save_plot) + if show_plot: + plt.show() @contextmanager @@ -132,6 +146,8 @@ def timed(start_msg: str, end_msg: str) -> Iterator[None]: def main() -> None: """Program entrypoint.""" + OUTPUT_DIR.mkdir(exist_ok=True) + with timed("Loading the MNIST dataset", "MNIST dataset loaded"): # Working with the entire dataset would be way too computationally expensive # (TSNE would take hours, if not more), instead, downsize the dataset and work @@ -165,8 +181,20 @@ def main() -> None: print() with timed("Showing graphs", "Finished"): - plot_2d(mnist_pca.x_train, mnist_pca.y_train, "2D PCA of MNIST") - plot_2d(mnist_tsne.x_train, mnist_tsne.y_train, "2D t-SNE of MNIST") + plot_2d( + mnist_pca.x_train, + mnist_pca.y_train, + "2D PCA of MNIST", + show_plot=True, + save_plot=OUTPUT_DIR.joinpath("pca.png"), + ) + plot_2d( + mnist_tsne.x_train, + mnist_tsne.y_train, + "2D t-SNE of MNIST", + show_plot=True, + save_plot=OUTPUT_DIR.joinpath("tsne.png"), + ) if __name__ == "__main__":