Time dependent Bayesian Optimization
In this example we demonstrate time-dependent optimization. Here we are interested not only in finding an optimum point in input space, but also in tracking that optimum as it moves over time.
# set values if testing
import os
import time
import warnings
import torch
from matplotlib import pyplot as plt
from tqdm import trange
from xopt.generators.bayesian import TDUpperConfidenceBoundGenerator
from xopt.vocs import VOCS
from xopt.evaluator import Evaluator
from xopt import Xopt
# Install a blanket "ignore" filter for all warnings.
warnings.filterwarnings("ignore")

# Any non-empty SMOKE_TEST environment variable shrinks the run to a single
# Monte Carlo sample, optimizer restart and optimization step.
SMOKE_TEST = os.environ.get("SMOKE_TEST")
if SMOKE_TEST:
    N_MC_SAMPLES, NUM_RESTARTS, N_STEPS = 1, 1, 1
else:
    N_MC_SAMPLES, NUM_RESTARTS, N_STEPS = 128, 20, 250
Time dependent test problem
Optimization is carried out over a single variable $x$. The test function is a simple
quadratic, $g(x, t) = (x - k(t))^2$, whose minimum location $k(t)$ drifts and changes as a function of time $t$.
Define test functions
# location of time dependent minimum
def k(t_):
    """Return the location of the objective's minimum at elapsed time ``t_``.

    For ``t_ < 50`` the location is ``0.25 * sin(0.6 * t_) + 1e-3 * t_``
    (an oscillation plus a slow drift); from ``t_ = 50`` onwards it is the
    linear ramp ``-1.5e-2 * (t_ - 50)``.

    Args:
        t_: Elapsed time(s). A tensor, or any value ``torch.as_tensor``
            accepts (for example a Python float).

    Returns:
        Tensor of minimum locations with the same shape as ``t_``.
    """
    # torch.where needs a tensor condition, so coerce plain numbers first;
    # tensors pass through torch.as_tensor unchanged.
    t_ = torch.as_tensor(t_)
    return torch.where(
        t_ < 50, 0.25 * torch.sin(t_ * 6 / 10.0) + 0.1e-2 * t_, -1.5e-2 * (t_ - 50.0)
    )
# objective in (position, time) space
def g(x_, t_):
    """Return the squared distance between ``x_`` and the minimum ``k(t_)``."""
    offset = x_ - k(t_)
    return offset**2
# objective callable handed to the Xopt Evaluator
def f(inputs):
    """Evaluate the drifting quadratic at the current wall-clock time.

    Reads ``inputs["x"]``, measures the seconds elapsed since the
    module-level ``start_time`` and evaluates ``g`` at that elapsed time.

    Returns:
        dict with the objective value under ``"y"`` and the absolute
        ``time.time()`` timestamp of the evaluation under ``"time"``.
    """
    position = inputs["x"]
    now = time.time()
    elapsed = torch.tensor(now - start_time)
    objective = g(position, elapsed)
    return {"y": float(objective), "time": float(now)}
Define Xopt objects including optimization algorithm
# Problem definition: one bounded input "x" and one objective "y" to minimize.
objectives = {"y": "MINIMIZE"}
variables = {"x": [-1, 1]}
vocs = VOCS(objectives=objectives, variables=variables)

# Wrap the test function f so that Xopt can call it.
evaluator = Evaluator(function=f)
Run optimization
# Time-dependent upper-confidence-bound generator for the first run.
generator = TDUpperConfidenceBoundGenerator(
    vocs=vocs,
    beta=0.01,
    added_time=0.1,
    forgetting_time=10.0,
)
generator.n_monte_carlo_samples = N_MC_SAMPLES
generator.numerical_optimizer.n_restarts = NUM_RESTARTS
generator.max_travel_distances = [0.1]
generator.gp_constructor.use_low_noise_prior = True

# Reference time that f() subtracts to obtain the elapsed time.
start_time = time.time()

X = Xopt(evaluator=evaluator, generator=generator)
X.random_evaluate(2)

for _step in trange(N_STEPS):
    # note that in this example we can ignore warnings if computation
    # time is greater than added time
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        X.step()
        time.sleep(0.1)
0%| | 0/250 [00:00<?, ?it/s]
0%| | 1/250 [00:06<25:17, 6.09s/it]
1%| | 2/250 [00:08<17:20, 4.20s/it]
1%| | 3/250 [00:11<13:20, 3.24s/it]
2%|▏ | 4/250 [00:14<13:17, 3.24s/it]
2%|▏ | 5/250 [00:18<14:28, 3.54s/it]
2%|▏ | 6/250 [00:22<15:00, 3.69s/it]
3%|▎ | 7/250 [00:25<14:25, 3.56s/it]
3%|▎ | 8/250 [00:28<12:50, 3.19s/it]
4%|▎ | 9/250 [00:31<13:26, 3.34s/it]
4%|▍ | 10/250 [00:35<13:39, 3.41s/it]
4%|▍ | 11/250 [00:37<12:24, 3.12s/it]
5%|▍ | 12/250 [00:40<11:20, 2.86s/it]
5%|▌ | 13/250 [00:44<13:01, 3.30s/it]
6%|▌ | 14/250 [00:48<13:26, 3.42s/it]
6%|▌ | 15/250 [00:50<12:07, 3.10s/it]
6%|▋ | 16/250 [00:52<10:26, 2.68s/it]
7%|▋ | 17/250 [00:53<09:04, 2.34s/it]
7%|▋ | 18/250 [00:55<07:57, 2.06s/it]
8%|▊ | 19/250 [00:57<07:53, 2.05s/it]
8%|▊ | 20/250 [00:59<07:47, 2.03s/it]
8%|▊ | 21/250 [01:00<06:35, 1.73s/it]
9%|▉ | 22/250 [01:01<06:20, 1.67s/it]
9%|▉ | 23/250 [01:02<06:00, 1.59s/it]
10%|▉ | 24/250 [01:03<05:15, 1.40s/it]
10%|█ | 25/250 [01:04<04:32, 1.21s/it]
10%|█ | 26/250 [01:05<04:04, 1.09s/it]
11%|█ | 27/250 [01:06<04:22, 1.18s/it]
11%|█ | 28/250 [01:08<04:28, 1.21s/it]
12%|█▏ | 29/250 [01:09<04:29, 1.22s/it]
12%|█▏ | 30/250 [01:10<04:37, 1.26s/it]
12%|█▏ | 31/250 [01:11<04:08, 1.14s/it]
13%|█▎ | 32/250 [01:12<03:49, 1.05s/it]
13%|█▎ | 33/250 [01:13<03:48, 1.05s/it]
14%|█▎ | 34/250 [01:14<03:30, 1.03it/s]
14%|█▍ | 35/250 [01:15<03:44, 1.04s/it]
14%|█▍ | 36/250 [01:17<04:18, 1.21s/it]
15%|█▍ | 37/250 [01:18<04:01, 1.13s/it]
15%|█▌ | 38/250 [01:18<03:10, 1.11it/s]
16%|█▌ | 39/250 [01:19<02:52, 1.22it/s]
16%|█▌ | 40/250 [01:19<02:39, 1.32it/s]
16%|█▋ | 41/250 [01:20<02:59, 1.16it/s]
17%|█▋ | 42/250 [01:21<03:02, 1.14it/s]
17%|█▋ | 43/250 [01:22<02:56, 1.17it/s]
18%|█▊ | 44/250 [01:22<02:31, 1.36it/s]
18%|█▊ | 45/250 [01:23<02:13, 1.53it/s]
18%|█▊ | 46/250 [01:24<02:20, 1.45it/s]
19%|█▉ | 47/250 [01:24<02:08, 1.58it/s]
19%|█▉ | 48/250 [01:25<01:58, 1.71it/s]
20%|█▉ | 49/250 [01:25<01:51, 1.81it/s]
20%|██ | 50/250 [01:26<02:04, 1.61it/s]
20%|██ | 51/250 [01:27<02:10, 1.53it/s]
21%|██ | 52/250 [01:27<02:04, 1.59it/s]
21%|██ | 53/250 [01:28<01:59, 1.65it/s]
22%|██▏ | 54/250 [01:29<02:08, 1.52it/s]
22%|██▏ | 55/250 [01:30<03:20, 1.03s/it]
22%|██▏ | 56/250 [01:32<03:57, 1.22s/it]
23%|██▎ | 57/250 [01:35<05:35, 1.74s/it]
23%|██▎ | 58/250 [01:36<04:47, 1.50s/it]
24%|██▎ | 59/250 [01:37<04:13, 1.33s/it]
24%|██▍ | 60/250 [01:38<03:29, 1.10s/it]
24%|██▍ | 61/250 [01:38<03:03, 1.03it/s]
25%|██▍ | 62/250 [01:39<02:44, 1.14it/s]
25%|██▌ | 63/250 [01:39<02:23, 1.31it/s]
26%|██▌ | 64/250 [01:40<02:24, 1.29it/s]
26%|██▌ | 65/250 [01:41<02:18, 1.33it/s]
26%|██▋ | 66/250 [01:42<02:29, 1.23it/s]
27%|██▋ | 67/250 [01:43<02:27, 1.24it/s]
27%|██▋ | 68/250 [01:43<02:14, 1.36it/s]
28%|██▊ | 69/250 [01:44<02:04, 1.45it/s]
28%|██▊ | 70/250 [01:44<02:01, 1.48it/s]
28%|██▊ | 71/250 [01:45<01:57, 1.52it/s]
29%|██▉ | 72/250 [01:46<02:00, 1.48it/s]
29%|██▉ | 73/250 [01:46<02:04, 1.43it/s]
30%|██▉ | 74/250 [01:47<01:51, 1.58it/s]
30%|███ | 75/250 [01:48<01:51, 1.57it/s]
30%|███ | 76/250 [01:48<02:03, 1.41it/s]
31%|███ | 77/250 [01:49<02:13, 1.30it/s]
31%|███ | 78/250 [01:50<02:06, 1.36it/s]
32%|███▏ | 79/250 [01:51<02:01, 1.41it/s]
32%|███▏ | 80/250 [01:52<02:28, 1.14it/s]
32%|███▏ | 81/250 [01:52<02:08, 1.31it/s]
33%|███▎ | 82/250 [01:53<02:10, 1.29it/s]
33%|███▎ | 83/250 [01:54<02:14, 1.24it/s]
34%|███▎ | 84/250 [01:55<02:04, 1.34it/s]
34%|███▍ | 85/250 [01:55<02:03, 1.34it/s]
34%|███▍ | 86/250 [01:56<02:11, 1.25it/s]
35%|███▍ | 87/250 [01:58<02:52, 1.06s/it]
35%|███▌ | 88/250 [01:59<02:47, 1.03s/it]
36%|███▌ | 89/250 [02:01<03:07, 1.16s/it]
36%|███▌ | 90/250 [02:01<02:41, 1.01s/it]
36%|███▋ | 91/250 [02:02<02:17, 1.16it/s]
37%|███▋ | 92/250 [02:02<02:01, 1.30it/s]
37%|███▋ | 93/250 [02:03<01:50, 1.43it/s]
38%|███▊ | 94/250 [02:04<01:51, 1.39it/s]
38%|███▊ | 95/250 [02:04<01:51, 1.39it/s]
38%|███▊ | 96/250 [02:05<01:55, 1.34it/s]
39%|███▉ | 97/250 [02:06<01:54, 1.34it/s]
39%|███▉ | 98/250 [02:07<02:04, 1.22it/s]
40%|███▉ | 99/250 [02:08<02:30, 1.00it/s]
40%|████ | 100/250 [02:09<02:08, 1.16it/s]
40%|████ | 101/250 [02:09<01:58, 1.26it/s]
41%|████ | 102/250 [02:10<01:49, 1.35it/s]
41%|████ | 103/250 [02:11<01:48, 1.36it/s]
42%|████▏ | 104/250 [02:11<01:39, 1.47it/s]
42%|████▏ | 105/250 [02:12<01:46, 1.37it/s]
42%|████▏ | 106/250 [02:13<01:40, 1.43it/s]
43%|████▎ | 107/250 [02:13<01:36, 1.49it/s]
43%|████▎ | 108/250 [02:14<01:27, 1.62it/s]
44%|████▎ | 109/250 [02:15<01:55, 1.22it/s]
44%|████▍ | 110/250 [02:16<02:01, 1.16it/s]
44%|████▍ | 111/250 [02:17<01:51, 1.25it/s]
45%|████▍ | 112/250 [02:17<01:40, 1.38it/s]
45%|████▌ | 113/250 [02:18<01:34, 1.45it/s]
46%|████▌ | 114/250 [02:19<01:40, 1.36it/s]
46%|████▌ | 115/250 [02:19<01:29, 1.51it/s]
46%|████▋ | 116/250 [02:20<01:22, 1.62it/s]
47%|████▋ | 117/250 [02:21<01:29, 1.49it/s]
47%|████▋ | 118/250 [02:21<01:30, 1.46it/s]
48%|████▊ | 119/250 [02:22<01:26, 1.52it/s]
48%|████▊ | 120/250 [02:23<01:31, 1.42it/s]
48%|████▊ | 121/250 [02:23<01:24, 1.53it/s]
49%|████▉ | 122/250 [02:24<01:23, 1.54it/s]
49%|████▉ | 123/250 [02:24<01:19, 1.60it/s]
50%|████▉ | 124/250 [02:25<01:25, 1.48it/s]
50%|█████ | 125/250 [02:26<01:19, 1.56it/s]
50%|█████ | 126/250 [02:26<01:15, 1.64it/s]
51%|█████ | 127/250 [02:27<01:15, 1.64it/s]
51%|█████ | 128/250 [02:27<01:12, 1.69it/s]
52%|█████▏ | 129/250 [02:28<01:05, 1.84it/s]
52%|█████▏ | 130/250 [02:29<01:09, 1.73it/s]
52%|█████▏ | 131/250 [02:30<01:43, 1.15it/s]
53%|█████▎ | 132/250 [02:31<01:33, 1.26it/s]
53%|█████▎ | 133/250 [02:31<01:24, 1.39it/s]
54%|█████▎ | 134/250 [02:32<01:21, 1.43it/s]
54%|█████▍ | 135/250 [02:33<01:17, 1.48it/s]
54%|█████▍ | 136/250 [02:33<01:18, 1.45it/s]
55%|█████▍ | 137/250 [02:34<01:13, 1.54it/s]
55%|█████▌ | 138/250 [02:34<01:04, 1.74it/s]
56%|█████▌ | 139/250 [02:35<01:00, 1.84it/s]
56%|█████▌ | 140/250 [02:35<00:58, 1.88it/s]
56%|█████▋ | 141/250 [02:36<01:01, 1.77it/s]
57%|█████▋ | 142/250 [02:37<01:03, 1.69it/s]
57%|█████▋ | 143/250 [02:37<01:07, 1.59it/s]
58%|█████▊ | 144/250 [02:38<01:06, 1.59it/s]
58%|█████▊ | 145/250 [02:38<01:05, 1.60it/s]
58%|█████▊ | 146/250 [02:39<00:59, 1.75it/s]
59%|█████▉ | 147/250 [02:39<00:56, 1.83it/s]
59%|█████▉ | 148/250 [02:40<01:03, 1.60it/s]
60%|█████▉ | 149/250 [02:41<01:06, 1.52it/s]
60%|██████ | 150/250 [02:41<01:01, 1.62it/s]
60%|██████ | 151/250 [02:42<00:56, 1.76it/s]
61%|██████ | 152/250 [02:42<00:54, 1.80it/s]
61%|██████ | 153/250 [02:43<00:56, 1.73it/s]
62%|██████▏ | 154/250 [02:44<01:01, 1.55it/s]
62%|██████▏ | 155/250 [02:45<01:01, 1.55it/s]
62%|██████▏ | 156/250 [02:45<01:01, 1.54it/s]
63%|██████▎ | 157/250 [02:46<01:01, 1.50it/s]
63%|██████▎ | 158/250 [02:46<00:55, 1.65it/s]
64%|██████▎ | 159/250 [02:47<00:51, 1.77it/s]
64%|██████▍ | 160/250 [02:47<00:50, 1.79it/s]
64%|██████▍ | 161/250 [02:48<00:53, 1.68it/s]
65%|██████▍ | 162/250 [02:49<00:57, 1.53it/s]
65%|██████▌ | 163/250 [02:49<00:55, 1.56it/s]
66%|██████▌ | 164/250 [02:50<00:51, 1.67it/s]
66%|██████▌ | 165/250 [02:50<00:49, 1.72it/s]
66%|██████▋ | 166/250 [02:51<00:49, 1.71it/s]
67%|██████▋ | 167/250 [02:52<00:49, 1.67it/s]
67%|██████▋ | 168/250 [02:52<00:44, 1.83it/s]
68%|██████▊ | 169/250 [02:53<00:46, 1.75it/s]
68%|██████▊ | 170/250 [02:53<00:46, 1.72it/s]
68%|██████▊ | 171/250 [02:54<00:45, 1.73it/s]
69%|██████▉ | 172/250 [02:55<00:48, 1.62it/s]
69%|██████▉ | 173/250 [02:55<00:42, 1.83it/s]
70%|██████▉ | 174/250 [02:56<00:44, 1.71it/s]
70%|███████ | 175/250 [02:56<00:47, 1.58it/s]
70%|███████ | 176/250 [02:57<00:48, 1.52it/s]
71%|███████ | 177/250 [02:58<00:51, 1.42it/s]
71%|███████ | 178/250 [02:59<00:49, 1.45it/s]
72%|███████▏ | 179/250 [03:00<00:55, 1.28it/s]
72%|███████▏ | 180/250 [03:00<00:49, 1.41it/s]
72%|███████▏ | 181/250 [03:01<00:45, 1.51it/s]
73%|███████▎ | 182/250 [03:01<00:42, 1.61it/s]
73%|███████▎ | 183/250 [03:02<00:42, 1.59it/s]
74%|███████▎ | 184/250 [03:02<00:39, 1.68it/s]
74%|███████▍ | 185/250 [03:03<00:42, 1.53it/s]
74%|███████▍ | 186/250 [03:04<00:41, 1.55it/s]
75%|███████▍ | 187/250 [03:04<00:36, 1.74it/s]
75%|███████▌ | 188/250 [03:05<00:35, 1.76it/s]
76%|███████▌ | 189/250 [03:06<00:39, 1.55it/s]
76%|███████▌ | 190/250 [03:06<00:36, 1.66it/s]
76%|███████▋ | 191/250 [03:07<00:41, 1.42it/s]
77%|███████▋ | 192/250 [03:08<00:37, 1.55it/s]
77%|███████▋ | 193/250 [03:08<00:36, 1.55it/s]
78%|███████▊ | 194/250 [03:09<00:38, 1.47it/s]
78%|███████▊ | 195/250 [03:10<00:38, 1.44it/s]
78%|███████▊ | 196/250 [03:11<00:39, 1.37it/s]
79%|███████▉ | 197/250 [03:11<00:37, 1.42it/s]
79%|███████▉ | 198/250 [03:12<00:33, 1.54it/s]
80%|███████▉ | 199/250 [03:12<00:33, 1.51it/s]
80%|████████ | 200/250 [03:13<00:31, 1.57it/s]
80%|████████ | 201/250 [03:14<00:33, 1.47it/s]
81%|████████ | 202/250 [03:15<00:33, 1.42it/s]
81%|████████ | 203/250 [03:15<00:30, 1.57it/s]
82%|████████▏ | 204/250 [03:15<00:26, 1.72it/s]
82%|████████▏ | 205/250 [03:16<00:26, 1.67it/s]
82%|████████▏ | 206/250 [03:17<00:27, 1.61it/s]
83%|████████▎ | 207/250 [03:18<00:29, 1.48it/s]
83%|████████▎ | 208/250 [03:18<00:26, 1.57it/s]
84%|████████▎ | 209/250 [03:19<00:23, 1.72it/s]
84%|████████▍ | 210/250 [03:19<00:24, 1.67it/s]
84%|████████▍ | 211/250 [03:20<00:24, 1.60it/s]
85%|████████▍ | 212/250 [03:21<00:25, 1.51it/s]
85%|████████▌ | 213/250 [03:22<00:26, 1.37it/s]
86%|████████▌ | 214/250 [03:22<00:26, 1.38it/s]
86%|████████▌ | 215/250 [03:23<00:25, 1.39it/s]
86%|████████▋ | 216/250 [03:24<00:24, 1.39it/s]
87%|████████▋ | 217/250 [03:24<00:22, 1.45it/s]
87%|████████▋ | 218/250 [03:25<00:21, 1.52it/s]
88%|████████▊ | 219/250 [03:25<00:20, 1.54it/s]
88%|████████▊ | 220/250 [03:26<00:17, 1.74it/s]
88%|████████▊ | 221/250 [03:26<00:16, 1.73it/s]
89%|████████▉ | 222/250 [03:27<00:15, 1.79it/s]
89%|████████▉ | 223/250 [03:28<00:17, 1.51it/s]
90%|████████▉ | 224/250 [03:28<00:15, 1.71it/s]
90%|█████████ | 225/250 [03:29<00:16, 1.50it/s]
90%|█████████ | 226/250 [03:30<00:16, 1.48it/s]
91%|█████████ | 227/250 [03:30<00:14, 1.54it/s]
91%|█████████ | 228/250 [03:31<00:13, 1.64it/s]
92%|█████████▏| 229/250 [03:31<00:11, 1.76it/s]
92%|█████████▏| 230/250 [03:32<00:11, 1.72it/s]
92%|█████████▏| 231/250 [03:33<00:11, 1.64it/s]
93%|█████████▎| 232/250 [03:33<00:11, 1.55it/s]
93%|█████████▎| 233/250 [03:34<00:10, 1.65it/s]
94%|█████████▎| 234/250 [03:35<00:12, 1.26it/s]
94%|█████████▍| 235/250 [03:36<00:10, 1.39it/s]
94%|█████████▍| 236/250 [03:37<00:10, 1.36it/s]
95%|█████████▍| 237/250 [03:37<00:08, 1.49it/s]
95%|█████████▌| 238/250 [03:37<00:07, 1.68it/s]
96%|█████████▌| 239/250 [03:38<00:05, 1.83it/s]
96%|█████████▌| 240/250 [03:38<00:05, 1.84it/s]
96%|█████████▋| 241/250 [03:39<00:04, 1.84it/s]
97%|█████████▋| 242/250 [03:39<00:04, 1.98it/s]
97%|█████████▋| 243/250 [03:40<00:03, 1.99it/s]
98%|█████████▊| 244/250 [03:40<00:03, 1.96it/s]
98%|█████████▊| 245/250 [03:41<00:02, 1.98it/s]
98%|█████████▊| 246/250 [03:41<00:02, 1.93it/s]
99%|█████████▉| 247/250 [03:42<00:01, 2.03it/s]
99%|█████████▉| 248/250 [03:43<00:01, 1.60it/s]
100%|█████████▉| 249/250 [03:43<00:00, 1.57it/s]
100%|██████████| 250/250 [03:44<00:00, 1.60it/s]
100%|██████████| 250/250 [03:44<00:00, 1.11it/s]
Visualize GP model of objective function and plot trajectory
# Evaluate the GP posterior and the true objective on an n x n (time, x) grid
# spanning the sampled time range and the variable bounds, then plot them.
data = X.data
xbounds = generator.vocs.bounds
tbounds = [data["time"].min(), data["time"].max()]

model = X.generator.model
n = 100
t = torch.linspace(*tbounds, n, dtype=torch.double)
x = torch.linspace(*torch.tensor(xbounds).flatten(), n, dtype=torch.double)

# Spell out indexing="ij" (first grid axis is time, second is x). This is what
# the implicit default gives today, but torch documents that default as slated
# to change to "xy", which would silently transpose every plot below.
tt, xx = torch.meshgrid(t, x, indexing="ij")
pts = torch.hstack([ele.reshape(-1, 1) for ele in (tt, xx)]).double()

tt, xx = tt.numpy(), xx.numpy()

# NOTE: the model inputs are such that t is the last dimension
# (flipping the columns turns each (t, x) row into (x, t))
gp_pts = torch.flip(pts, dims=[-1])

# ground truth is defined in elapsed time, hence the start_time offset
gt_vals = g(gp_pts.T[0], gp_pts.T[1] - start_time)

with torch.no_grad():
    post = model.posterior(gp_pts)

    mean = post.mean
    std = torch.sqrt(post.variance)

# posterior mean with the sampled trajectory and the ideal path k(t)
fig, ax = plt.subplots()
ax.set_title("model mean")
ax.set_xlabel("unix time")
ax.set_ylabel("x")
c = ax.pcolor(tt, xx, mean.reshape(n, n), rasterized=True)
ax.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1", label="samples")

ax.plot(t, k(t - start_time), "C3--", label="ideal path", zorder=10)
ax.legend()
fig.colorbar(c)

# posterior standard deviation with the sampled trajectory
fig2, ax2 = plt.subplots()
ax2.set_title("model uncertainty")
ax2.set_xlabel("unix time")
ax2.set_ylabel("x")
c = ax2.pcolor(tt, xx, std.reshape(n, n))
fig2.colorbar(c)
ax2.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1")

# true objective with the sampled trajectory
fig3, ax3 = plt.subplots()
ax3.set_title("ground truth value")
ax3.set_xlabel("unix time")
ax3.set_ylabel("x")
c = ax3.pcolor(tt, xx, gt_vals.reshape(n, n))
fig3.colorbar(c)
ax3.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1")
Plot the acquisition function
# note that target time is only updated during the generate call
target_time = X.generator.target_prediction_time
# elapsed seconds (relative to start_time) of the target prediction time
print(target_time - start_time)

my_acq_func = X.generator.get_acquisition(model)

with torch.no_grad():
    acq_pts = x.unsqueeze(-1).unsqueeze(-1)
    # inner acquisition function evaluated on the full (x, t) grid
    full_acq = my_acq_func.acq_func(gp_pts.unsqueeze(1))
    # wrapped acquisition function evaluated on x only
    # NOTE(review): presumably the wrapper pins the time input to the target
    # prediction time -- confirm against the generator implementation
    fixed_acq = my_acq_func(acq_pts)

fig, ax = plt.subplots()
c = ax.pcolor(tt, xx, full_acq.reshape(n, n))
ax.set_xlabel("unix time")
ax.set_ylabel("x")
ax.set_title("acquisition function")
fig.colorbar(c)

# `fig2` (previously misspelled `fi2`) so the figure handle matches `ax2`
# instead of leaving `fig2` bound to an older figure
fig2, ax2 = plt.subplots()
ax2.plot(x.flatten(), fixed_acq.flatten())
ax2.set_xlabel("x")
ax2.set_ylabel("acquisition function")
ax2.set_title("acquisition function at last time step")
224.12396121025085
Run Time Dependent BO with Model Caching
Instead of retraining the GP model hyperparameters at every step, we can hold
on to previously determined model parameters by setting
use_cached_hyperparameters=True in the model constructor. This reduces the time
needed to make decisions, leading to faster feedback when addressing time-critical
optimization tasks. However, this can come at the cost of model accuracy when the
target function changes behavior (for example, a change in lengthscale).
# Fresh generator for the second run.
generator = TDUpperConfidenceBoundGenerator(
    vocs=vocs,
    beta=0.01,
    added_time=0.1,
    forgetting_time=20.0,
)
generator.n_monte_carlo_samples = N_MC_SAMPLES
generator.numerical_optimizer.n_restarts = NUM_RESTARTS
generator.max_travel_distances = [0.1]

# Reset the reference time that f() subtracts to obtain the elapsed time.
start_time = time.time()

X = Xopt(evaluator=evaluator, generator=generator)
X.random_evaluate(2)

for i in trange(N_STEPS):
    # once step 50 is reached, switch the GP constructor to cached
    # hyperparameters for all remaining steps
    if i == 50:
        X.generator.gp_constructor.use_cached_hyperparameters = True

    # note that in this example we can ignore warnings if computation time is
    # greater than added time
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        X.step()
        time.sleep(0.1)
0%| | 0/250 [00:00<?, ?it/s]
0%| | 1/250 [00:00<01:30, 2.76it/s]
1%| | 2/250 [00:00<01:26, 2.86it/s]
1%| | 3/250 [00:01<01:47, 2.30it/s]
2%|▏ | 4/250 [00:01<01:52, 2.18it/s]
2%|▏ | 5/250 [00:02<01:44, 2.35it/s]
2%|▏ | 6/250 [00:02<02:15, 1.80it/s]
3%|▎ | 7/250 [00:03<02:21, 1.71it/s]
3%|▎ | 8/250 [00:04<02:32, 1.59it/s]
4%|▎ | 9/250 [00:04<02:28, 1.62it/s]
4%|▍ | 10/250 [00:05<02:26, 1.64it/s]
4%|▍ | 11/250 [00:05<02:16, 1.75it/s]
5%|▍ | 12/250 [00:06<02:14, 1.77it/s]
5%|▌ | 13/250 [00:07<02:16, 1.74it/s]
6%|▌ | 14/250 [00:07<02:12, 1.78it/s]
6%|▌ | 15/250 [00:08<02:09, 1.82it/s]
6%|▋ | 16/250 [00:09<02:55, 1.33it/s]
7%|▋ | 17/250 [00:10<03:35, 1.08it/s]
7%|▋ | 18/250 [00:11<03:34, 1.08it/s]
8%|▊ | 19/250 [00:13<04:23, 1.14s/it]
8%|▊ | 20/250 [00:14<04:44, 1.24s/it]
8%|▊ | 21/250 [00:15<04:31, 1.18s/it]
9%|▉ | 22/250 [00:16<03:56, 1.04s/it]
9%|▉ | 23/250 [00:17<03:32, 1.07it/s]
10%|▉ | 24/250 [00:17<03:13, 1.17it/s]
10%|█ | 25/250 [00:18<03:01, 1.24it/s]
10%|█ | 26/250 [00:19<02:53, 1.29it/s]
11%|█ | 27/250 [00:20<02:58, 1.25it/s]
11%|█ | 28/250 [00:20<02:42, 1.37it/s]
12%|█▏ | 29/250 [00:21<02:40, 1.38it/s]
12%|█▏ | 30/250 [00:22<02:50, 1.29it/s]
12%|█▏ | 31/250 [00:23<02:57, 1.24it/s]
13%|█▎ | 32/250 [00:24<02:59, 1.22it/s]
13%|█▎ | 33/250 [00:24<02:52, 1.26it/s]
14%|█▎ | 34/250 [00:25<02:39, 1.36it/s]
14%|█▍ | 35/250 [00:26<02:51, 1.25it/s]
14%|█▍ | 36/250 [00:26<02:45, 1.30it/s]
15%|█▍ | 37/250 [00:27<02:32, 1.40it/s]
15%|█▌ | 38/250 [00:28<02:43, 1.30it/s]
16%|█▌ | 39/250 [00:29<02:33, 1.37it/s]
16%|█▌ | 40/250 [00:29<02:12, 1.58it/s]
16%|█▋ | 41/250 [00:30<02:05, 1.66it/s]
17%|█▋ | 42/250 [00:30<02:26, 1.42it/s]
17%|█▋ | 43/250 [00:31<02:13, 1.55it/s]
18%|█▊ | 44/250 [00:32<02:09, 1.58it/s]
18%|█▊ | 45/250 [00:32<01:53, 1.80it/s]
18%|█▊ | 46/250 [00:33<01:52, 1.82it/s]
19%|█▉ | 47/250 [00:33<01:57, 1.73it/s]
19%|█▉ | 48/250 [00:34<01:59, 1.69it/s]
20%|█▉ | 49/250 [00:34<01:49, 1.84it/s]
20%|██ | 50/250 [00:35<02:10, 1.53it/s]
20%|██ | 51/250 [00:35<01:49, 1.81it/s]
21%|██ | 52/250 [00:36<01:47, 1.85it/s]
21%|██ | 53/250 [00:36<01:41, 1.94it/s]
22%|██▏ | 54/250 [00:37<01:33, 2.10it/s]
22%|██▏ | 55/250 [00:37<01:21, 2.38it/s]
22%|██▏ | 56/250 [00:38<01:41, 1.91it/s]
23%|██▎ | 57/250 [00:38<01:42, 1.89it/s]
23%|██▎ | 58/250 [00:39<01:32, 2.08it/s]
24%|██▎ | 59/250 [00:39<01:29, 2.13it/s]
24%|██▍ | 60/250 [00:40<01:38, 1.93it/s]
24%|██▍ | 61/250 [00:40<01:36, 1.96it/s]
25%|██▍ | 62/250 [00:41<01:23, 2.25it/s]
25%|██▌ | 63/250 [00:41<01:20, 2.33it/s]
26%|██▌ | 64/250 [00:41<01:14, 2.49it/s]
26%|██▌ | 65/250 [00:42<01:08, 2.69it/s]
26%|██▋ | 66/250 [00:42<01:03, 2.88it/s]
27%|██▋ | 67/250 [00:42<00:59, 3.08it/s]
27%|██▋ | 68/250 [00:43<01:05, 2.78it/s]
28%|██▊ | 69/250 [00:43<01:02, 2.89it/s]
28%|██▊ | 70/250 [00:43<00:57, 3.13it/s]
28%|██▊ | 71/250 [00:43<00:53, 3.36it/s]
29%|██▉ | 72/250 [00:44<00:52, 3.38it/s]
29%|██▉ | 73/250 [00:44<00:49, 3.59it/s]
30%|██▉ | 74/250 [00:44<00:56, 3.14it/s]
30%|███ | 75/250 [00:45<00:50, 3.49it/s]
30%|███ | 76/250 [00:45<00:49, 3.54it/s]
31%|███ | 77/250 [00:45<00:54, 3.18it/s]
31%|███ | 78/250 [00:46<00:52, 3.25it/s]
32%|███▏ | 79/250 [00:46<00:51, 3.33it/s]
32%|███▏ | 80/250 [00:46<00:50, 3.34it/s]
32%|███▏ | 81/250 [00:47<00:54, 3.12it/s]
33%|███▎ | 82/250 [00:47<01:24, 2.00it/s]
33%|███▎ | 83/250 [00:48<01:16, 2.19it/s]
34%|███▎ | 84/250 [00:48<01:09, 2.39it/s]
34%|███▍ | 85/250 [00:49<01:08, 2.42it/s]
34%|███▍ | 86/250 [00:49<01:04, 2.53it/s]
35%|███▍ | 87/250 [00:49<01:00, 2.69it/s]
35%|███▌ | 88/250 [00:49<00:53, 3.02it/s]
36%|███▌ | 89/250 [00:50<00:51, 3.15it/s]
36%|███▌ | 90/250 [00:50<01:11, 2.24it/s]
36%|███▋ | 91/250 [00:51<01:10, 2.26it/s]
37%|███▋ | 92/250 [00:51<01:06, 2.37it/s]
37%|███▋ | 93/250 [00:52<01:03, 2.47it/s]
38%|███▊ | 94/250 [00:52<01:02, 2.48it/s]
38%|███▊ | 95/250 [00:52<01:04, 2.41it/s]
38%|███▊ | 96/250 [00:53<01:01, 2.51it/s]
39%|███▉ | 97/250 [00:53<01:02, 2.45it/s]
39%|███▉ | 98/250 [00:54<01:04, 2.36it/s]
40%|███▉ | 99/250 [00:54<01:01, 2.46it/s]
40%|████ | 100/250 [00:55<01:04, 2.32it/s]
40%|████ | 101/250 [00:55<01:02, 2.38it/s]
41%|████ | 102/250 [00:55<01:02, 2.37it/s]
41%|████ | 103/250 [00:56<01:05, 2.24it/s]
42%|████▏ | 104/250 [00:56<01:07, 2.17it/s]
42%|████▏ | 105/250 [00:57<01:07, 2.15it/s]
42%|████▏ | 106/250 [00:57<01:10, 2.04it/s]
43%|████▎ | 107/250 [00:58<01:29, 1.59it/s]
43%|████▎ | 108/250 [00:59<01:42, 1.39it/s]
44%|████▎ | 109/250 [01:00<01:39, 1.42it/s]
44%|████▍ | 110/250 [01:00<01:27, 1.60it/s]
44%|████▍ | 111/250 [01:01<01:12, 1.92it/s]
45%|████▍ | 112/250 [01:01<01:12, 1.91it/s]
45%|████▌ | 113/250 [01:02<01:09, 1.96it/s]
46%|████▌ | 114/250 [01:02<01:07, 2.03it/s]
46%|████▌ | 115/250 [01:03<01:03, 2.12it/s]
46%|████▋ | 116/250 [01:03<00:57, 2.33it/s]
47%|████▋ | 117/250 [01:03<00:58, 2.27it/s]
47%|████▋ | 118/250 [01:04<00:55, 2.39it/s]
48%|████▊ | 119/250 [01:04<00:47, 2.77it/s]
48%|████▊ | 120/250 [01:04<00:44, 2.95it/s]
48%|████▊ | 121/250 [01:05<00:46, 2.77it/s]
49%|████▉ | 122/250 [01:05<00:45, 2.83it/s]
49%|████▉ | 123/250 [01:05<00:49, 2.57it/s]
50%|████▉ | 124/250 [01:06<00:47, 2.65it/s]
50%|█████ | 125/250 [01:06<00:46, 2.67it/s]
50%|█████ | 126/250 [01:07<00:46, 2.68it/s]
51%|█████ | 127/250 [01:07<00:47, 2.61it/s]
51%|█████ | 128/250 [01:07<00:48, 2.52it/s]
52%|█████▏ | 129/250 [01:08<00:48, 2.51it/s]
52%|█████▏ | 130/250 [01:08<00:48, 2.45it/s]
52%|█████▏ | 131/250 [01:09<00:51, 2.33it/s]
53%|█████▎ | 132/250 [01:09<00:55, 2.14it/s]
53%|█████▎ | 133/250 [01:10<00:57, 2.05it/s]
54%|█████▎ | 134/250 [01:11<01:13, 1.58it/s]
54%|█████▍ | 135/250 [01:11<00:58, 1.96it/s]
54%|█████▍ | 136/250 [01:11<00:50, 2.26it/s]
55%|█████▍ | 137/250 [01:12<00:42, 2.65it/s]
55%|█████▌ | 138/250 [01:12<00:37, 3.03it/s]
56%|█████▌ | 139/250 [01:12<00:33, 3.34it/s]
56%|█████▌ | 140/250 [01:12<00:30, 3.61it/s]
56%|█████▋ | 141/250 [01:12<00:28, 3.85it/s]
57%|█████▋ | 142/250 [01:13<00:26, 4.00it/s]
57%|█████▋ | 143/250 [01:13<00:25, 4.13it/s]
58%|█████▊ | 144/250 [01:13<00:29, 3.57it/s]
58%|█████▊ | 145/250 [01:13<00:27, 3.78it/s]
58%|█████▊ | 146/250 [01:14<00:27, 3.80it/s]
59%|█████▉ | 147/250 [01:14<00:27, 3.75it/s]
59%|█████▉ | 148/250 [01:14<00:25, 3.93it/s]
60%|█████▉ | 149/250 [01:14<00:26, 3.88it/s]
60%|██████ | 150/250 [01:15<00:27, 3.58it/s]
60%|██████ | 151/250 [01:15<00:26, 3.72it/s]
61%|██████ | 152/250 [01:15<00:25, 3.82it/s]
61%|██████ | 153/250 [01:16<00:25, 3.75it/s]
62%|██████▏ | 154/250 [01:16<00:25, 3.72it/s]
62%|██████▏ | 155/250 [01:16<00:25, 3.78it/s]
62%|██████▏ | 156/250 [01:16<00:24, 3.88it/s]
63%|██████▎ | 157/250 [01:17<00:23, 4.00it/s]
63%|██████▎ | 158/250 [01:17<00:22, 4.03it/s]
64%|██████▎ | 159/250 [01:17<00:22, 4.07it/s]
64%|██████▍ | 160/250 [01:17<00:23, 3.89it/s]
64%|██████▍ | 161/250 [01:18<00:22, 4.01it/s]
65%|██████▍ | 162/250 [01:18<00:21, 4.14it/s]
65%|██████▌ | 163/250 [01:18<00:21, 4.13it/s]
66%|██████▌ | 164/250 [01:18<00:23, 3.73it/s]
66%|██████▌ | 165/250 [01:19<00:22, 3.78it/s]
66%|██████▋ | 166/250 [01:19<00:21, 3.90it/s]
67%|██████▋ | 167/250 [01:19<00:20, 4.01it/s]
67%|██████▋ | 168/250 [01:19<00:19, 4.11it/s]
68%|██████▊ | 169/250 [01:20<00:21, 3.82it/s]
68%|██████▊ | 170/250 [01:20<00:20, 3.95it/s]
68%|██████▊ | 171/250 [01:20<00:19, 4.09it/s]
69%|██████▉ | 172/250 [01:20<00:21, 3.65it/s]
69%|██████▉ | 173/250 [01:21<00:20, 3.80it/s]
70%|██████▉ | 174/250 [01:21<00:21, 3.58it/s]
70%|███████ | 175/250 [01:21<00:19, 3.79it/s]
70%|███████ | 176/250 [01:21<00:18, 3.99it/s]
71%|███████ | 177/250 [01:22<00:17, 4.13it/s]
71%|███████ | 178/250 [01:22<00:17, 4.19it/s]
72%|███████▏ | 179/250 [01:22<00:16, 4.25it/s]
72%|███████▏ | 180/250 [01:22<00:16, 4.35it/s]
72%|███████▏ | 181/250 [01:23<00:17, 3.98it/s]
73%|███████▎ | 182/250 [01:23<00:16, 4.14it/s]
73%|███████▎ | 183/250 [01:23<00:17, 3.82it/s]
74%|███████▎ | 184/250 [01:23<00:16, 3.96it/s]
74%|███████▍ | 185/250 [01:24<00:16, 4.04it/s]
74%|███████▍ | 186/250 [01:24<00:15, 4.14it/s]
75%|███████▍ | 187/250 [01:24<00:14, 4.24it/s]
75%|███████▌ | 188/250 [01:24<00:14, 4.27it/s]
76%|███████▌ | 189/250 [01:25<00:14, 4.31it/s]
76%|███████▌ | 190/250 [01:25<00:18, 3.23it/s]
76%|███████▋ | 191/250 [01:25<00:18, 3.27it/s]
77%|███████▋ | 192/250 [01:26<00:16, 3.53it/s]
77%|███████▋ | 193/250 [01:26<00:15, 3.69it/s]
78%|███████▊ | 194/250 [01:26<00:14, 3.80it/s]
78%|███████▊ | 195/250 [01:26<00:13, 3.95it/s]
78%|███████▊ | 196/250 [01:27<00:13, 4.07it/s]
79%|███████▉ | 197/250 [01:27<00:12, 4.16it/s]
79%|███████▉ | 198/250 [01:27<00:12, 4.27it/s]
80%|███████▉ | 199/250 [01:27<00:11, 4.31it/s]
80%|████████ | 200/250 [01:27<00:11, 4.39it/s]
80%|████████ | 201/250 [01:28<00:10, 4.48it/s]
81%|████████ | 202/250 [01:28<00:10, 4.53it/s]
81%|████████ | 203/250 [01:28<00:10, 4.62it/s]
82%|████████▏ | 204/250 [01:28<00:09, 4.62it/s]
82%|████████▏ | 205/250 [01:28<00:10, 4.50it/s]
82%|████████▏ | 206/250 [01:29<00:09, 4.56it/s]
83%|████████▎ | 207/250 [01:29<00:09, 4.55it/s]
83%|████████▎ | 208/250 [01:29<00:09, 4.56it/s]
84%|████████▎ | 209/250 [01:29<00:08, 4.56it/s]
84%|████████▍ | 210/250 [01:30<00:08, 4.57it/s]
84%|████████▍ | 211/250 [01:30<00:08, 4.51it/s]
85%|████████▍ | 212/250 [01:30<00:08, 4.45it/s]
85%|████████▌ | 213/250 [01:30<00:08, 4.40it/s]
86%|████████▌ | 214/250 [01:30<00:08, 4.38it/s]
86%|████████▌ | 215/250 [01:31<00:08, 4.37it/s]
86%|████████▋ | 216/250 [01:31<00:07, 4.27it/s]
87%|████████▋ | 217/250 [01:31<00:07, 4.20it/s]
87%|████████▋ | 218/250 [01:31<00:07, 4.27it/s]
88%|████████▊ | 219/250 [01:32<00:07, 4.32it/s]
88%|████████▊ | 220/250 [01:32<00:07, 3.96it/s]
88%|████████▊ | 221/250 [01:32<00:07, 4.03it/s]
89%|████████▉ | 222/250 [01:32<00:06, 4.23it/s]
89%|████████▉ | 223/250 [01:33<00:06, 4.34it/s]
90%|████████▉ | 224/250 [01:33<00:05, 4.35it/s]
90%|█████████ | 225/250 [01:33<00:05, 4.43it/s]
90%|█████████ | 226/250 [01:33<00:05, 4.38it/s]
91%|█████████ | 227/250 [01:34<00:05, 4.40it/s]
91%|█████████ | 228/250 [01:34<00:04, 4.46it/s]
92%|█████████▏| 229/250 [01:34<00:04, 4.49it/s]
92%|█████████▏| 230/250 [01:34<00:04, 4.51it/s]
92%|█████████▏| 231/250 [01:34<00:04, 4.53it/s]
93%|█████████▎| 232/250 [01:35<00:03, 4.57it/s]
93%|█████████▎| 233/250 [01:35<00:03, 4.53it/s]
94%|█████████▎| 234/250 [01:35<00:03, 4.55it/s]
94%|█████████▍| 235/250 [01:35<00:03, 4.49it/s]
94%|█████████▍| 236/250 [01:36<00:03, 4.48it/s]
95%|█████████▍| 237/250 [01:36<00:02, 4.43it/s]
95%|█████████▌| 238/250 [01:36<00:02, 4.38it/s]
96%|█████████▌| 239/250 [01:36<00:02, 4.36it/s]
96%|█████████▌| 240/250 [01:36<00:02, 4.38it/s]
96%|█████████▋| 241/250 [01:37<00:02, 4.21it/s]
97%|█████████▋| 242/250 [01:37<00:01, 4.19it/s]
97%|█████████▋| 243/250 [01:37<00:01, 4.23it/s]
98%|█████████▊| 244/250 [01:37<00:01, 4.30it/s]
98%|█████████▊| 245/250 [01:38<00:01, 4.29it/s]
98%|█████████▊| 246/250 [01:39<00:02, 2.00it/s]
99%|█████████▉| 247/250 [01:40<00:02, 1.34it/s]
99%|█████████▉| 248/250 [01:40<00:01, 1.65it/s]
100%|█████████▉| 249/250 [01:41<00:00, 1.73it/s]
100%|██████████| 250/250 [01:41<00:00, 2.15it/s]
100%|██████████| 250/250 [01:41<00:00, 2.46it/s]
# plot total computation time: the generator's computation_time table is
# summed across its columns (axis=1) and plotted against its index
total_time = X.generator.computation_time.sum(axis=1)
ax = total_time.plot()
ax.set_xlabel("Iteration")
ax.set_ylabel("total BO computation time (s)")
Text(0, 0.5, 'total BO computation time (s)')
# Evaluate the GP posterior and the true objective on an n x n (time, x) grid
# spanning the sampled time range and the variable bounds, then plot them.
data = X.data
xbounds = generator.vocs.bounds
tbounds = [data["time"].min(), data["time"].max()]

model = X.generator.model
n = 100
t = torch.linspace(*tbounds, n, dtype=torch.double)
x = torch.linspace(*torch.tensor(xbounds).flatten(), n, dtype=torch.double)

# Spell out indexing="ij" (first grid axis is time, second is x). This is what
# the implicit default gives today, but torch documents that default as slated
# to change to "xy", which would silently transpose every plot below.
tt, xx = torch.meshgrid(t, x, indexing="ij")
pts = torch.hstack([ele.reshape(-1, 1) for ele in (tt, xx)]).double()

tt, xx = tt.numpy(), xx.numpy()

# NOTE: the model inputs are such that t is the last dimension
# (flipping the columns turns each (t, x) row into (x, t))
gp_pts = torch.flip(pts, dims=[-1])

# ground truth is defined in elapsed time, hence the start_time offset
gt_vals = g(gp_pts.T[0], gp_pts.T[1] - start_time)

with torch.no_grad():
    post = model.posterior(gp_pts)

    mean = post.mean
    std = torch.sqrt(post.variance)

# posterior mean with the sampled trajectory and the ideal path k(t)
fig, ax = plt.subplots()
ax.set_title("model mean")
ax.set_xlabel("unix time")
ax.set_ylabel("x")
c = ax.pcolor(tt, xx, mean.reshape(n, n))
ax.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1", label="samples")

ax.plot(t, k(t - start_time), "C3--", label="ideal path", zorder=10)
ax.legend()
fig.colorbar(c)

# posterior standard deviation with the sampled trajectory
fig2, ax2 = plt.subplots()
ax2.set_title("model uncertainty")
ax2.set_xlabel("unix time")
ax2.set_ylabel("x")
c = ax2.pcolor(tt, xx, std.reshape(n, n))
fig2.colorbar(c)
ax2.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1")

# true objective with the sampled trajectory
fig3, ax3 = plt.subplots()
ax3.set_title("ground truth value")
ax3.set_xlabel("unix time")
ax3.set_ylabel("x")
c = ax3.pcolor(tt, xx, gt_vals.reshape(n, n))
fig3.colorbar(c)
ax3.plot(data["time"].to_numpy(), data["x"].to_numpy(), "oC1")