diff --git a/pyproject.toml b/pyproject.toml index 518d7d8..b0ee238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ authors = [ requires-python = ">=3.12" dependencies = [ "numpy>=2.2.3", + "pyqt6>=6.8.1", ] [dependency-groups] diff --git a/src/__main__.py b/src/__main__.py index 9f31ca7..23914c4 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,26 +1,24 @@ -from src.functions import available_functions -from src.search import min_search +import os +import sys + +from PyQt6.QtWidgets import QApplication + +from src.ui import OptimizationUI def main() -> None: """Program entrypoint.""" - dims = 3 - iters = 100 - neighbors = 10 - stddev = 1 + # The pyqt6 python package only ships Windows and Fusion styles, + # ignore any style overrides unless it's one of these two to avoid warnings. + # Default to Fusion in this case. + qt_style = os.environ.get("QT_STYLE_OVERRIDE") + if qt_style not in {"Windows", "Fusion"}: + os.environ["QT_STYLE_OVERRIDE"] = "Fusion" - for func_name, function in available_functions.items(): - print("-" * 80) - print(func_name) - for x in min_search( - function, - function.definition_interval.random_point(dims), - iterations=iters, - neighbors_count=neighbors, - std_dev=stddev, - ): - y = function(x) - print(f"{x} -> {y}") + app = QApplication(sys.argv) + window = OptimizationUI() + window.show() + sys.exit(app.exec()) if __name__ == "__main__": diff --git a/src/ui.py b/src/ui.py new file mode 100644 index 0000000..8b8fe43 --- /dev/null +++ b/src/ui.py @@ -0,0 +1,157 @@ +from typing import TypedDict + +import numpy as np +from PyQt6.QtCore import Qt +from PyQt6.QtWidgets import ( + QButtonGroup, + QComboBox, + QLabel, + QLineEdit, + QPushButton, + QRadioButton, + QSlider, + QVBoxLayout, + QWidget, +) + +from src.function import Function +from src.functions import available_functions +from src.search import min_search +from src.types import INPUT_VECTOR + + +class OptimizationSettings(TypedDict): + """Dictionary that holds the settings chosen for the optimization algorithm.""" + + function: Function + x0: INPUT_VECTOR + iterations: int + neighbors_count: int + std_dev: float + include_center: bool + rng: np.random.Generator | None + + +class OptimizationUI(QWidget): + """Class that represents the UI from which the user can control the optimization parameters.""" + + def __init__(self) -> None: + super().__init__() + self.setWindowTitle("Optimization Algorithm Settings") + self.setGeometry(100, 100, 400, 300) + + layout = QVBoxLayout() + + # Function Selection Dropdown + self.function_label = QLabel("Select Function:") + self.function_dropdown = QComboBox() + self.function_dropdown.addItems(available_functions.keys()) + + # Iterations Slider + self.iter_label = QLabel("Iterations: 1") + self.iter_slider = QSlider(Qt.Orientation.Horizontal) + self.iter_slider.setRange(1, 1000) + _ = self.iter_slider.valueChanged.connect( + lambda: self.iter_label.setText(f"Iterations: {self.iter_slider.value()}") + ) + + # StdDev Slider + self.stddev_label = QLabel("StdDev: 0") + self.stddev_slider = QSlider(Qt.Orientation.Horizontal) + self.stddev_slider.setRange(0, 30) + _ = self.stddev_slider.valueChanged.connect( + lambda: self.stddev_label.setText(f"StdDev: {self.stddev_slider.value()}") + ) + + # Neighbors Count + self.neighbors_label = QLabel("Neighbors: 1") + self.neighbors_slider = QSlider(Qt.Orientation.Horizontal) + self.neighbors_slider.setRange(1, 100) + _ = self.neighbors_slider.valueChanged.connect( + lambda: self.neighbors_label.setText(f"Neighbors: {self.neighbors_slider.value()}") + ) + + # Dimensions Count + self.dimensions_label = QLabel("Dimensions: 2") + self.dimensions_slider = QSlider(Qt.Orientation.Horizontal) + self.dimensions_slider.setRange(2, 100) + _ = self.dimensions_slider.valueChanged.connect( + lambda: self.dimensions_label.setText(f"Dimensions: {self.dimensions_slider.value()}") + ) + + # Algorithm Selection (Radio Buttons) + self.radio_group = QButtonGroup(self) + self.local_search_radio = QRadioButton("Local Search (Include center points)") + self.hill_climb_radio = QRadioButton("Stochastic Hill Climber (Only new points)") + self.radio_group.addButton(self.local_search_radio) + self.radio_group.addButton(self.hill_climb_radio) + self.local_search_radio.setChecked(True) + + # RNG Seed Input + self.seed_label = QLabel("RNG Seed:") + self.seed_input = QLineEdit() + self.seed_input.setPlaceholderText("Enter seed (optional)") + + # Run Button + self.run_button = QPushButton("Run") + _ = self.run_button.clicked.connect(self.run_algorithm) + + # Adding widgets to layout + layout.addWidget(self.function_label) + layout.addWidget(self.function_dropdown) + layout.addWidget(self.iter_label) + layout.addWidget(self.iter_slider) + layout.addWidget(self.stddev_label) + layout.addWidget(self.stddev_slider) + layout.addWidget(self.neighbors_label) + layout.addWidget(self.neighbors_slider) + layout.addWidget(self.dimensions_label) + layout.addWidget(self.dimensions_slider) + layout.addWidget(self.local_search_radio) + layout.addWidget(self.hill_climb_radio) + layout.addWidget(self.seed_label) + layout.addWidget(self.seed_input) + layout.addWidget(self.run_button) + + self.setLayout(layout) + + @property + def current_settings(self) -> OptimizationSettings: + """Get currently selected settings for the optimization algorithm.""" + func_name = self.function_dropdown.currentText() + function = available_functions[func_name] + + seed = self.seed_input.text() + if seed != "": + bit_gen = np.random.PCG64(abs(hash(seed))) + rng = np.random.Generator(bit_gen) + else: + rng = None + + dims = self.dimensions_slider.value() + x0 = function.definition_interval.random_point(dims, rng=rng) + + return OptimizationSettings( + { + "function": function, + "x0": x0, + "iterations": self.iter_slider.value(), + "std_dev": self.stddev_slider.value(), + "neighbors_count": self.neighbors_slider.value(), + "include_center": self.local_search_radio.isChecked(), + "rng": rng, + }, + ) + + def run_algorithm(self) -> None: + """Run the search algorithm according to selected settings. + + This is a call-back for pressing the run button. + """ + settings = self.current_settings + print("-" * 80) + print(self.function_dropdown.currentText()) + print(settings) + for x in min_search(**settings): + y = settings["function"](x) + print(f"{x} -> {y}") diff --git a/uv.lock b/uv.lock index 56cb2d4..38dc38d 100644 --- a/uv.lock +++ b/uv.lock @@ -138,6 +138,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 }, ] +[[package]] +name = "pyqt6" +version = "6.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyqt6-qt6" }, + { name = "pyqt6-sip" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/bf/ff284a136b39cb1873c18e4fca4a40a8847c84a1910c5fb38c6a77868968/pyqt6-6.8.1.tar.gz", hash = "sha256:91d937d6166274fafd70f4dee11a8da6dbfdb0da53de05f5d62361ddf775e256", size = 1064723 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/da/70971b3d7f53a68644ea32544d3786dfbbb162d18572ac1defcf5a6481d5/PyQt6-6.8.1-cp39-abi3-macosx_10_14_universal2.whl", hash = "sha256:0425f9eebdd5d4e57ab36424c9382f2ea06670c3c550fa0028c2b19bd0a1d7bd", size = 12213924 }, + { url = "https://files.pythonhosted.org/packages/be/25/a4392c323a0fb97eb5f449b7594f37e93d9794b900756b43cd65772def77/PyQt6-6.8.1-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:36bf48e3df3a6ff536e703315d155480ef4e260396eb5469eb7a875bc5bb7ab4", size = 8238120 }, + { url = "https://files.pythonhosted.org/packages/de/a3/e528b4cc3394f2ae15b531c17f27b53de756a8c0404dfa9c184502367c48/PyQt6-6.8.1-cp39-abi3-manylinux_2_39_aarch64.whl", hash = "sha256:2eac2267a34828b8db7660dd3cc3b3b5fd76a92e61ad45471565b01221cb558b", size = 8173996 }, + { url = "https://files.pythonhosted.org/packages/f2/69/11404cfcb916bd7207805c21432ecab0401779361d48b67f28ae9337f70d/PyQt6-6.8.1-cp39-abi3-win_amd64.whl", hash = "sha256:70bad7b890a8f9e9e5fb9598c544b832d9d9d99a9519e0009cb29c1e15e96632", size = 6723466 }, + { url = "https://files.pythonhosted.org/packages/00/2a/21a555aea9bc8abc4f09017b922dbdf509c421f70506d4c83d2e8f4315b2/PyQt6-6.8.1-cp39-abi3-win_arm64.whl", hash = "sha256:a40f878e8e5eeeb0bba995152d07eeef9375ea0116df0f4aad0a6b97c8ad1175", size = 5463379 }, +] + +[[package]] +name = "pyqt6-qt6" +version = "6.8.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/a4/3d764e05955382b3dc7227cbfde090700edd63431147f1c66d428ccac45c/PyQt6_Qt6-6.8.2-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:470dd4211fe5a67b0565e0202e7aa67816e5dcf7d713528b88327adaebd0934e", size = 66121240 }, + { url = "https://files.pythonhosted.org/packages/d6/b3/6d4f8257b46554fb2c89b33a6773a3f05ed961b3cd83828caee5dc79899f/PyQt6_Qt6-6.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:40cda901a3e1617e79225c354fe9d89b80249f0a6c6aaa18b40938e05bbf7d1f", size = 60286219 }, + { url = "https://files.pythonhosted.org/packages/92/95/0036435b9e2cbd22e08f14eec2362c32fc641660c6e4aea6f59d165cb5fc/PyQt6_Qt6-6.8.2-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:fb6d0acdd7d43c33fb8b9d2dd7922d381cdedd00da316049fbe01fc1973e6f05", size = 81263397 }, + { url = "https://files.pythonhosted.org/packages/6e/fb/c01dde044eca1542d88cac72fc99369af76a981cc2f52790236efa566e01/PyQt6_Qt6-6.8.2-py3-none-manylinux_2_39_aarch64.whl", hash = "sha256:5970c85d22cbe5c476418994549161b23ed938e25b04fc4ca8fabf6dcac7b03f", size = 79832921 }, + { url = "https://files.pythonhosted.org/packages/1a/f7/31f03a9f5e6c7cc23ceb2bd0d9c2df0518837f7af0e693e15b6e0881b8b0/PyQt6_Qt6-6.8.2-py3-none-win_amd64.whl", hash = "sha256:28e2bb641f05b01e498503c3ef01c8a919d6e0e96b50230301c0baac2b7d1433", size = 71934164 }, + { url = "https://files.pythonhosted.org/packages/00/c9/102c9537795ca11c12120ec9d5f554d9437787f52d8e23fbc8269e6a2699/PyQt6_Qt6-6.8.2-py3-none-win_arm64.whl", hash = "sha256:912afdddd0dfc666ce1c16bc4695e2acd680db72343e4f7a2b7c053a0146b4bc", size = 48120018 }, +] + +[[package]] +name = "pyqt6-sip" +version = "13.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/18/0405c54acba0c8e276dd6f0601890e6e735198218d031a6646104870fe22/pyqt6_sip-13.10.0.tar.gz", hash = "sha256:d6daa95a0bd315d9ec523b549e0ce97455f61ded65d5eafecd83ed2aa4ae5350", size = 92464 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/81/66d9bdacb790592a0641378749a047f12e3b254cdc2cb51f7ed636cf01d2/PyQt6_sip-13.10.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:48791db2914fc39c3218519a02d2a5fd3fcd354a1be3141a57bf2880701486f2", size = 112334 }, + { url = "https://files.pythonhosted.org/packages/26/2c/4796c209009a018e0d4a5c406d5a519234c5a378f370dc679d0ad5f455b2/PyQt6_sip-13.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:466d6b4791973c9fcbdc2e0087ed194b9ea802a8c3948867a849498f0841c70c", size = 322334 }, + { url = "https://files.pythonhosted.org/packages/99/34/2ec54bd475f0a811df1d32be485f2344cf9e8b388ce7adb26b46ce5552d4/PyQt6_sip-13.10.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ae15358941f127cd3d1ab09c1ebd45c4dabb0b2e91587b9eebde0279d0039c54", size = 303798 }, + { url = "https://files.pythonhosted.org/packages/0c/e4/82099bb4ab8bc152b5718541e93c0b3adf7566c0f307c9e58e2368b8c517/PyQt6_sip-13.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:ad573184fa8b00041944e5a17d150ab0d08db2d2189e39c9373574ebab3f2e58", size = 53569 }, + { url = "https://files.pythonhosted.org/packages/e3/09/90e0378887a3cb9664da77061229cf8e97e6ec25a5611b7dbc9cc3e02c78/PyQt6_sip-13.10.0-cp312-cp312-win_arm64.whl", hash = "sha256:2d579d810d0047d40bde9c6aef281d6ed218db93c9496ebc9e55b9e6f27a229d", size = 45430 }, + { url = "https://files.pythonhosted.org/packages/6b/0c/8d1de48b45b565a46bf4757341f13f9b1853a7d2e6b023700f0af2c213ab/PyQt6_sip-13.10.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7b6e250c2e7c14702a623f2cc1479d7fb8db2b6eee9697cac10d06fe79c281bb", size = 112343 }, + { url = "https://files.pythonhosted.org/packages/af/13/e2cc2b667a9f5d44c2d0e18fa6e1066fca3f4521dcb301f4b5374caeb33e/PyQt6_sip-13.10.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fcb30756568f8cd59290f9ef2ae5ee3e72ff9cdd61a6f80c9e3d3b95ae676be", size = 322527 }, + { url = "https://files.pythonhosted.org/packages/20/1a/5c6fcae85edb65cf236c9dc6d23b279b5316e94cdca1abdee6d0a217ddbb/PyQt6_sip-13.10.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:757ac52c92b2ef0b56ecc7cd763b55a62d3c14271d7ea8d03315af85a70090ff", size = 303407 }, + { url = "https://files.pythonhosted.org/packages/b9/db/6924ec985be7d746772806b96ab81d24263ef72f0249f0573a82adaed75e/PyQt6_sip-13.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:571900c44a3e38738d696234d94fe2043972b9de0633505451c99e2922cb6a34", size = 53580 }, + { url = "https://files.pythonhosted.org/packages/77/c3/9e44729b582ee7f1d45160e8c292723156889f3e38ce6574f88d5ab8fa02/PyQt6_sip-13.10.0-cp313-cp313-win_arm64.whl", hash = "sha256:39cba2cc71cf80a99b4dc8147b43508d4716e128f9fb99f5eb5860a37f082282", size = 45446 }, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -195,6 +243,7 @@ version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "numpy" }, + { name = "pyqt6" }, ] [package.dev-dependencies] @@ -205,7 +254,10 @@ lint = [ ] [package.metadata] -requires-dist = [{ name = "numpy", specifier = ">=2.2.3" }] +requires-dist = [ + { name = "numpy", specifier = ">=2.2.3" }, + { name = "pyqt6", specifier = ">=6.8.1" }, +] [package.metadata.requires-dev] lint = [