Add logic to save the result graphs

This commit is contained in:
Peter Vacho 2024-12-09 13:40:11 +01:00
parent 8b8d882f1c
commit 0e0cc48619
Signed by: school
GPG key ID: 8CFC3837052871B4

View file

@ -2,6 +2,7 @@ import os
from collections.abc import Iterator from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from time import perf_counter from time import perf_counter
from typing import cast, final from typing import cast, final
@ -20,6 +21,8 @@ from sklearn.utils import Bunch
# Set to -1 to disable downsizing # Set to -1 to disable downsizing
DOWNSIZE_MNIST: int = 8000 DOWNSIZE_MNIST: int = 8000
OUTPUT_DIR = Path.cwd() / "output"
# pyqt6 only bundles Windows & Fusion styles, which means that if you use a # pyqt6 only bundles Windows & Fusion styles, which means that if you use a
# different preferred qt style, a warning would be produced. This gets rid # different preferred qt style, a warning would be produced. This gets rid
# of that warning and removes the env-var override. # of that warning and removes the env-var override.
@ -111,11 +114,22 @@ def knn_accuracy(data: MLDataset, k_range: range | None = None) -> tuple[object,
return best_k, acc return best_k, acc
def plot_2d(x: DataFrame, y: "Series[str]", title: str) -> None: def plot_2d(
x: DataFrame,
y: "Series[str]",
title: str,
*,
show_plot: bool = False,
save_plot: Path | None = None,
) -> None:
"""Show a 2D visualization of the given 2D (dim-reduced) dataset.""" """Show a 2D visualization of the given 2D (dim-reduced) dataset."""
plt.figure(figsize=(8, 6)) plt.figure(figsize=(8, 6))
sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15) sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15)
plt.title(title) plt.title(title)
if save_plot:
plt.savefig(save_plot)
if show_plot:
plt.show() plt.show()
@ -132,6 +146,8 @@ def timed(start_msg: str, end_msg: str) -> Iterator[None]:
def main() -> None: def main() -> None:
"""Program entrypoint.""" """Program entrypoint."""
OUTPUT_DIR.mkdir(exist_ok=True)
with timed("Loading the MNIST dataset", "MNIST dataset loaded"): with timed("Loading the MNIST dataset", "MNIST dataset loaded"):
# Working with the entire dataset would be way too computationally expensive # Working with the entire dataset would be way too computationally expensive
# (TSNE would take hours, if not more), instead, downsize the dataset and work # (TSNE would take hours, if not more), instead, downsize the dataset and work
@ -165,8 +181,20 @@ def main() -> None:
print() print()
with timed("Showing graphs", "Finished"): with timed("Showing graphs", "Finished"):
plot_2d(mnist_pca.x_train, mnist_pca.y_train, "2D PCA of MNIST") plot_2d(
plot_2d(mnist_tsne.x_train, mnist_tsne.y_train, "2D t-SNE of MNIST") mnist_pca.x_train,
mnist_pca.y_train,
"2D PCA of MNIST",
show_plot=True,
save_plot=OUTPUT_DIR.joinpath("pca.png"),
)
plot_2d(
mnist_tsne.x_train,
mnist_tsne.y_train,
"2D t-SNE of MNIST",
show_plot=True,
save_plot=OUTPUT_DIR.joinpath("tsne.png"),
)
if __name__ == "__main__": if __name__ == "__main__":