Basic local search
This commit is contained in:
parent
bcf22269e8
commit
6cb016d5e9
2 changed files with 104 additions and 1 deletions
|
@ -1,6 +1,65 @@
|
|||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.function import Function
|
||||
from src.functions import available_functions
|
||||
from src.types import INPUT_VECTOR
|
||||
from src.utils import generate_bounded_points
|
||||
|
||||
|
||||
def search(
|
||||
function: Function,
|
||||
x0: INPUT_VECTOR,
|
||||
iterations: int,
|
||||
neighbors_count: int,
|
||||
std_dev: float,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> Iterator[INPUT_VECTOR]:
|
||||
"""Search for the minimum value of the function using the local search algorithm.
|
||||
|
||||
On each iteration, N neighboring points will be generated around the starting point. These points
|
||||
will then be evaluated on the function, finding the smallest one. This smallest point will become
|
||||
the new starting point, repeating until we run out of iterations.
|
||||
|
||||
Params:
|
||||
x0: Starting point (N-dimensional).
|
||||
iterations: Maximum number of iterations.
|
||||
include_origin: When searching for the next minimum, should the origin point be checked too?
|
||||
neighbors_count: The amount of neighbor points.
|
||||
std_dev: Standard deviation for the normal distribution for neighbor generating.
|
||||
rng: Random generator instance (None for a new rng).
|
||||
|
||||
Yields:
|
||||
Minimum input vector (x) found so far, yielded from each iteration.
|
||||
"""
|
||||
if rng is None:
|
||||
rng = np.random.default_rng()
|
||||
x_center = x0
|
||||
|
||||
for _ in range(iterations):
|
||||
y_min = function(x_center)
|
||||
for point in generate_bounded_points(x_center, neighbors_count, std_dev, function.definition_interval, rng):
|
||||
y_point = function(point)
|
||||
if y_point < y_min:
|
||||
y_min = y_point
|
||||
x_center = point
|
||||
|
||||
yield x_center
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Program entrypoint."""
|
||||
print("Hello from task1!")
|
||||
dims = 3
|
||||
iters = 100
|
||||
neighbors = 10
|
||||
stddev = 1
|
||||
|
||||
for func_name, function in available_functions.items():
|
||||
print(func_name)
|
||||
for x in search(function, function.definition_interval.random_point(dims), iters, neighbors, stddev):
|
||||
y = function(x)
|
||||
print(f"{x} -> {y}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
44
src/utils.py
Normal file
44
src/utils.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import numpy as np
|
||||
|
||||
from src.function import Interval
|
||||
from src.types import INPUT_VECTOR
|
||||
|
||||
|
||||
def generate_bounded_points(
|
||||
center: INPUT_VECTOR,
|
||||
num_points: int,
|
||||
std_dev: float,
|
||||
interval: Interval,
|
||||
rng: np.random.Generator | None = None,
|
||||
) -> list[INPUT_VECTOR]:
|
||||
"""Generate `num_points` random points around `center`, following a normal distribution.
|
||||
|
||||
Args:
|
||||
center: The N-dimensional center point.
|
||||
num_points: Number of points to generate.
|
||||
std_dev: Standard deviation for the normal distribution.
|
||||
interval: Interval, determining the min/max values for each dimension.
|
||||
rng: Random generator instance (None for a new rng)
|
||||
|
||||
Returns:
|
||||
A list of generated points.
|
||||
"""
|
||||
dimensions = center.shape[0]
|
||||
valid_points: list[INPUT_VECTOR] = []
|
||||
|
||||
if rng is None:
|
||||
rng = np.random.default_rng()
|
||||
|
||||
while len(valid_points) < num_points:
|
||||
# Generate a batch of points (larger batch to reduce iterations)
|
||||
batch_size = (num_points - len(valid_points)) * 2 # Oversampling to reduce loop iterations
|
||||
candidates = rng.normal(loc=center, scale=std_dev, size=(batch_size, dimensions))
|
||||
|
||||
# Keep only those within bounds
|
||||
mask = (candidates >= interval.min) & (candidates <= interval.max)
|
||||
valid_candidates = candidates[np.all(mask, axis=1)]
|
||||
|
||||
# Add valid points to the list
|
||||
valid_points.extend(valid_candidates[: num_points - len(valid_points)])
|
||||
|
||||
return valid_points
|
Loading…
Add table
Reference in a new issue