Add logic to save the result graphs
This commit is contained in:
parent
8b8d882f1c
commit
0e0cc48619
|
@ -2,6 +2,7 @@ import os
|
|||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import cast, final
|
||||
|
||||
|
@ -20,6 +21,8 @@ from sklearn.utils import Bunch
|
|||
# Set to -1 to disable downsizing
|
||||
DOWNSIZE_MNIST: int = 8000
|
||||
|
||||
OUTPUT_DIR = Path.cwd() / "output"
|
||||
|
||||
# pyqt6 only bundles Windows & Fusion styles, which means that if you use a
|
||||
# different preferred qt style, a warning would be produced. This gets rid
|
||||
# of that warning and removes the env-var override.
|
||||
|
@ -111,12 +114,23 @@ def knn_accuracy(data: MLDataset, k_range: range | None = None) -> tuple[object,
|
|||
return best_k, acc
|
||||
|
||||
|
||||
def plot_2d(x: DataFrame, y: "Series[str]", title: str) -> None:
|
||||
def plot_2d(
|
||||
x: DataFrame,
|
||||
y: "Series[str]",
|
||||
title: str,
|
||||
*,
|
||||
show_plot: bool = False,
|
||||
save_plot: Path | None = None,
|
||||
) -> None:
|
||||
"""Show a 2D visualization of the given 2D (dim-reduced) dataset."""
|
||||
plt.figure(figsize=(8, 6))
|
||||
sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15)
|
||||
plt.title(title)
|
||||
plt.show()
|
||||
|
||||
if save_plot:
|
||||
plt.savefig(save_plot)
|
||||
if show_plot:
|
||||
plt.show()
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -132,6 +146,8 @@ def timed(start_msg: str, end_msg: str) -> Iterator[None]:
|
|||
|
||||
def main() -> None:
|
||||
"""Program entrypoint."""
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
with timed("Loading the MNIST dataset", "MNIST dataset loaded"):
|
||||
# Working with the entire dataset would be way too computationally expensive
|
||||
# (TSNE would take hours, if not more), instead, downsize the dataset and work
|
||||
|
@ -165,8 +181,20 @@ def main() -> None:
|
|||
print()
|
||||
|
||||
with timed("Showing graphs", "Finished"):
|
||||
plot_2d(mnist_pca.x_train, mnist_pca.y_train, "2D PCA of MNIST")
|
||||
plot_2d(mnist_tsne.x_train, mnist_tsne.y_train, "2D t-SNE of MNIST")
|
||||
plot_2d(
|
||||
mnist_pca.x_train,
|
||||
mnist_pca.y_train,
|
||||
"2D PCA of MNIST",
|
||||
show_plot=True,
|
||||
save_plot=OUTPUT_DIR.joinpath("pca.png"),
|
||||
)
|
||||
plot_2d(
|
||||
mnist_tsne.x_train,
|
||||
mnist_tsne.y_train,
|
||||
"2D t-SNE of MNIST",
|
||||
show_plot=True,
|
||||
save_plot=OUTPUT_DIR.joinpath("tsne.png"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue