Add logic to save the result graphs
This commit is contained in:
parent
8b8d882f1c
commit
0e0cc48619
|
@ -2,6 +2,7 @@ import os
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import cast, final
|
from typing import cast, final
|
||||||
|
|
||||||
|
@ -20,6 +21,8 @@ from sklearn.utils import Bunch
|
||||||
# Set to -1 to disable downsizing
|
# Set to -1 to disable downsizing
|
||||||
DOWNSIZE_MNIST: int = 8000
|
DOWNSIZE_MNIST: int = 8000
|
||||||
|
|
||||||
|
OUTPUT_DIR = Path.cwd() / "output"
|
||||||
|
|
||||||
# pyqt6 only bundles Windows & Fusion styles, which means that if you use a
|
# pyqt6 only bundles Windows & Fusion styles, which means that if you use a
|
||||||
# different preferred qt style, a warning would be produced. This gets rid
|
# different preferred qt style, a warning would be produced. This gets rid
|
||||||
# of that warning and removes the env-var override.
|
# of that warning and removes the env-var override.
|
||||||
|
@ -111,12 +114,23 @@ def knn_accuracy(data: MLDataset, k_range: range | None = None) -> tuple[object,
|
||||||
return best_k, acc
|
return best_k, acc
|
||||||
|
|
||||||
|
|
||||||
def plot_2d(x: DataFrame, y: "Series[str]", title: str) -> None:
|
def plot_2d(
|
||||||
|
x: DataFrame,
|
||||||
|
y: "Series[str]",
|
||||||
|
title: str,
|
||||||
|
*,
|
||||||
|
show_plot: bool = False,
|
||||||
|
save_plot: Path | None = None,
|
||||||
|
) -> None:
|
||||||
"""Show a 2D visualization of the given 2D (dim-reduced) dataset."""
|
"""Show a 2D visualization of the given 2D (dim-reduced) dataset."""
|
||||||
plt.figure(figsize=(8, 6))
|
plt.figure(figsize=(8, 6))
|
||||||
sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15)
|
sns.scatterplot(x=x[:, 0], y=x[:, 1], hue=y, palette="tab10", legend="full", s=15)
|
||||||
plt.title(title)
|
plt.title(title)
|
||||||
plt.show()
|
|
||||||
|
if save_plot:
|
||||||
|
plt.savefig(save_plot)
|
||||||
|
if show_plot:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -132,6 +146,8 @@ def timed(start_msg: str, end_msg: str) -> Iterator[None]:
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Program entrypoint."""
|
"""Program entrypoint."""
|
||||||
|
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
with timed("Loading the MNIST dataset", "MNIST dataset loaded"):
|
with timed("Loading the MNIST dataset", "MNIST dataset loaded"):
|
||||||
# Working with the entire dataset would be way too computationally expensive
|
# Working with the entire dataset would be way too computationally expensive
|
||||||
# (TSNE would take hours, if not more), instead, downsize the dataset and work
|
# (TSNE would take hours, if not more), instead, downsize the dataset and work
|
||||||
|
@ -165,8 +181,20 @@ def main() -> None:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
with timed("Showing graphs", "Finished"):
|
with timed("Showing graphs", "Finished"):
|
||||||
plot_2d(mnist_pca.x_train, mnist_pca.y_train, "2D PCA of MNIST")
|
plot_2d(
|
||||||
plot_2d(mnist_tsne.x_train, mnist_tsne.y_train, "2D t-SNE of MNIST")
|
mnist_pca.x_train,
|
||||||
|
mnist_pca.y_train,
|
||||||
|
"2D PCA of MNIST",
|
||||||
|
show_plot=True,
|
||||||
|
save_plot=OUTPUT_DIR.joinpath("pca.png"),
|
||||||
|
)
|
||||||
|
plot_2d(
|
||||||
|
mnist_tsne.x_train,
|
||||||
|
mnist_tsne.y_train,
|
||||||
|
"2D t-SNE of MNIST",
|
||||||
|
show_plot=True,
|
||||||
|
save_plot=OUTPUT_DIR.joinpath("tsne.png"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue