Continuous Multi-fidelity Optimization
[14]:
try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False
[15]:
# For notebook gallery thumbnail
if not IN_COLAB:
    from IPython.display import Image, display

    display(Image(filename="continuous-multi-fidelity.png"))
In the previous notebook, I provided a brief introduction to multi-fidelity optimization in the context of the physical sciences. This notebook covers Bayesian optimization with two candidate continuous fidelity parameters (atime and astep), of which only atime is varied here (see below). We’ll compare the total integration time of the multi-fidelity optimization with the integration time costs of running at the lowest fidelities, the highest fidelities (default), and approximately halfway in-between. For validation, the objective (in this case, frechet) will be evaluated at the upper bounds of atime and astep. In this experiment, we’ll allow atime to vary between 0 and 100 (its upper limit is 255), which corresponds to a step size of \(2.78\,\mu s\) to \(280.78\,\mu s\). Since multiple fidelity parameters aren’t yet supported in Ax (i.e., they require using BoTorch directly), we’ll fix astep to 999 (its limits are 0..65534). In terms of physical integration time, this corresponds to lowest and highest integration times of \(2.78\,ms\) and \(281\,ms\), respectively.
For reference, here are the docs for atime and astep (taken from `public_mqtt_sdl_demo/lib/as7341_sensor.py <https://github.com/sparks-baird/self-driving-lab-demo/blob/main/src/public_mqtt_sdl_demo/lib/as7341_sensor.py>`__):
"""
...
atime : int, optional
The integration time step size in 2.78 microsecond increments, by default 100
astep : int, optional
The integration time step count. Total integration time will be (ATIME + 1)
* (ASTEP + 1) * 2.78µS, by default 999, meaning 281 ms assuming atime=100
...
"""
We’ll be using the Knowledge Gradient (KG) acquisition function.
From the BoTorch docs:

> ### The one-shot Knowledge Gradient acquisition function
>
> The Knowledge Gradient (KG) (see [2, 3]) is a look-ahead acquisition function that quantifies the expected increase in the maximum of the modeled black-box function f from obtaining additional (random) observations collected at the candidate set x. KG often shows improved Bayesian Optimization performance relative to simpler acquisition functions such as Expected Improvement, but in its traditional form it is computationally expensive and hard to implement.
>
> …
>
> [2] P. Frazier, W. Powell, and S. Dayanik. A Knowledge-Gradient policy for sequential information collection. SIAM Journal on Control and Optimization, 2008.
>
> [3] J. Wu and P. Frazier. The parallel knowledge gradient method for batch bayesian optimization. NIPS 2016.
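To make "expected increase in the maximum" concrete, here is a toy sketch of KG in its simplest setting: independent normal beliefs about a finite set of alternatives with known observation noise, which is the setting analyzed by Frazier, Powell, and Dayanik [2]. It is not BoTorch's one-shot KG (which uses a GP model over a continuous space), and the helper names and numbers are illustrative assumptions. It is written for maximization; minimizing frechet is equivalent to maximizing its negative. The closed form is checked against a Monte Carlo simulation of one noisy observation followed by a conjugate Bayesian update.

```python
import math
import random


def norm_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def norm_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def kg_closed_form(mu, var, noise_var, i):
    """KG of observing alternative i once, for independent normal beliefs (maximization)."""
    # Std. dev. of the change in the posterior mean of i caused by one noisy observation
    sigma_tilde = var[i] / math.sqrt(var[i] + noise_var)
    best_other = max(m for j, m in enumerate(mu) if j != i)
    z = -abs(mu[i] - best_other) / sigma_tilde
    return sigma_tilde * (z * norm_cdf(z) + norm_pdf(z))


def kg_monte_carlo(mu, var, noise_var, i, n_samples=100_000, seed=0):
    """Estimate E[max posterior mean after observing i] - max current posterior mean."""
    rng = random.Random(seed)
    best_now = max(mu)
    best_other = max(m for j, m in enumerate(mu) if j != i)
    total = 0.0
    for _ in range(n_samples):
        # Draw an observation from the predictive distribution y ~ N(mu_i, var_i + noise_var)
        y = rng.gauss(mu[i], math.sqrt(var[i] + noise_var))
        # Conjugate normal update of the belief about alternative i only
        mu_new = (mu[i] / var[i] + y / noise_var) / (1.0 / var[i] + 1.0 / noise_var)
        total += max(mu_new, best_other)
    return total / n_samples - best_now


mu = [1.0, 1.2, 0.5]  # current posterior means (illustrative)
var = [1.0, 0.25, 2.0]  # current posterior variances (illustrative)
noise_var = 0.5  # observation noise variance (illustrative)

for i in range(len(mu)):
    closed = kg_closed_form(mu, var, noise_var, i)
    mc = kg_monte_carlo(mu, var, noise_var, i)
    print(f"alternative {i}: closed form {closed:.4f}, Monte Carlo {mc:.4f}")
```

Alternative 1 has the highest mean but little remaining uncertainty, so its KG is small; the more uncertain alternatives score higher even though their means are lower. This look-ahead behavior is what lets KG trade off cheap, informative low-fidelity measurements against expensive high-fidelity ones.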
In this tutorial, we’ll still use Ax rather than delving into pure BoTorch code. Since Ax doesn’t yet support discrete fidelity parameters, we’ll use BoTorch exclusively in the next notebook. BoTorch has a tutorial for continuous multi-fidelity optimization, which has been adapted for Ax in a GitHub issue, and we largely base the implementation here on the example from that issue.
As before, we need to set up our SelfDrivingLabDemoLight instances. We will use the physical experimental setting, since the simulation only accounts for integration time through a simple multiplier. We’ll also use the same frechet objective function.
[16]:
if IN_COLAB:
    %pip install self-driving-lab-demo
[17]:
from uuid import uuid4  # universally unique identifier
from self_driving_lab_demo import SelfDrivingLabDemoLight, mqtt_observe_sensor_data
dummy = True  # @param {type:"boolean"}
pico_id = "test"  # @param {type:"string"}
log_to_mongodb = False  # speed up the evaluation
if dummy:
    num_repeats = 2
    atime_max = 5
    astep_max = 5
    time_limit_multiplier = 1
else:
    num_repeats = 5  # @param {type:"integer"}
    atime_max = 100
    astep_max = 999  # fixed for now (see note above)
    time_limit_multiplier = 10  # @param {type:"number"}
model_gen_kwargs = (
    dict(num_fantasies=2, num_restarts=2, raw_samples=8) if dummy else None
)
simulation = False  # @param {type:"boolean"}
SESSION_ID = str(uuid4())  # random session ID
def calc_integration_time_s(atime, astep):
    """
    Calculate integration time (i.e., time cost) of light sensor.
    atime : int, optional
        The integration time step size in 2.78 microsecond increments, by default 100
    astep : int, optional
        The integration time step count. Total integration time will be (ATIME + 1)
        * (ASTEP + 1) * 2.78µS, by default 999, meaning 281 ms assuming atime=100
    """
    return ((atime + 1) * (astep + 1) * 2.78) / 1e6
# total seconds of integration time
time_limit_s = time_limit_multiplier * calc_integration_time_s(atime_max, astep_max)
seeds = range(10, 10 + num_repeats)
print(f"session ID: {SESSION_ID}")
sdls = [
    SelfDrivingLabDemoLight(
        autoload=True,  # perform target data experiment automatically
        simulation=simulation,
        observe_sensor_data_fn=mqtt_observe_sensor_data,  # (default)
        observe_sensor_data_kwargs=dict(
            pico_id=pico_id, session_id=SESSION_ID, mongodb=log_to_mongodb
        ),
        target_seed=seed,
    )
    for seed in seeds
]
session ID: f0d33d7d-c270-474f-a45a-e5873c994745
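The budget `time_limit_s` is `time_limit_multiplier` times the cost of one evaluation at the highest fidelity. Because astep is fixed, the \(2.78\,\mu s\) factor and the \((\text{astep}+1)\) factor cancel, so the number of whole evaluations that fit in the budget at a given atime reduces to integer arithmetic. The sketch below uses a hypothetical helper `evaluations_within_budget`, assumes an integer `time_limit_multiplier`, and takes "halfway" to mean `atime_max // 2`; it gives rough single-fidelity baselines for both the dummy and full settings:

```python
def evaluations_within_budget(atime, atime_max, time_limit_multiplier):
    """Number of whole evaluations at `atime` that fit in the integration-time budget.

    budget / cost = multiplier * (atime_max + 1) * (astep + 1) / ((atime + 1) * (astep + 1)),
    so astep and the 2.78 us step cancel; integer division avoids float rounding.
    """
    return (time_limit_multiplier * (atime_max + 1)) // (atime + 1)


for label, atime_max, multiplier in [("dummy", 5, 1), ("full", 100, 10)]:
    counts = {
        fidelity: evaluations_within_budget(atime, atime_max, multiplier)
        for fidelity, atime in [
            ("lowest", 0),
            ("halfway", atime_max // 2),
            ("highest", atime_max),
        ]
    }
    print(label, counts)
```

With the dummy settings the budget buys a single highest-fidelity evaluation or six lowest-fidelity ones, which is why the KG campaigns below favor atime = 0.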
[18]:
bounds = dict(R=sdls[0].bounds["R"], G=sdls[0].bounds["G"], B=sdls[0].bounds["B"])
params = [dict(name=nm, type="range", bounds=bnd) for nm, bnd in bounds.items()]
atime_bnd = [0, atime_max]  # instead of [0, 255]
astep_bnd = [0, astep_max]  # instead of [0, 65534]
params.append(
    dict(
        name="atime",
        type="range",
        is_fidelity=True,
        bounds=atime_bnd,
        target_value=atime_bnd[1],
    )
)
params.append(
    dict(
        name="astep",
        type="fixed",
        value=astep_max,
    )
)
params
[18]:
[{'name': 'R', 'type': 'range', 'bounds': [0, 89]},
{'name': 'G', 'type': 'range', 'bounds': [0, 89]},
{'name': 'B', 'type': 'range', 'bounds': [0, 89]},
{'name': 'atime',
'type': 'range',
'is_fidelity': True,
'bounds': [0, 5],
'target_value': 5},
{'name': 'astep', 'type': 'fixed', 'value': 5}]
[19]:
from ax.service.ax_client import AxClient
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.utils.instantiation import ObjectiveProperties
import torch
from time import time
integration_time_name = "integration_time_s"
model_runtime_name = "model_runtime_s"
tracking_metric_names = [integration_time_name, model_runtime_name]
campaign_objects = []
objective_name = "frechet"
batch_size = 1
num_sobol = 6
objectives = {objective_name: ObjectiveProperties(minimize=True)}
model_kwargs = (
    {"torch_device": torch.device("cuda")} if torch.cuda.is_available() else None
)
print(f"model_kwargs: {model_kwargs}")
for i, sdl in enumerate(sdls):
    def evaluate(parameters):
        results = sdl.evaluate(parameters)
        # # remove channel names to prevent extra tracking metrics warnings
        # [results.pop(ch) for ch in sdl.channel_names]
        atime = parameters["atime"]
        astep = parameters["astep"]
        # report the integration time as a tracking metric alongside the objective
        return {
            objective_name: results[objective_name],
            integration_time_name: calc_integration_time_s(atime, astep),
        }
    gs = GenerationStrategy(
        steps=[
            GenerationStep(model=Models.SOBOL, num_trials=num_sobol),
            GenerationStep(
                model=Models.GPKG,
                num_trials=-1,
                model_kwargs=model_kwargs,
                model_gen_kwargs=model_gen_kwargs,
            ),
        ]
    )
    ax_client = AxClient(generation_strategy=gs)
    ax_client.create_experiment(
        name="sdl_demo_mf_experiment",
        parameters=params,
        objectives=objectives,
        tracking_metric_names=tracking_metric_names,
        overwrite_existing_experiment=True,
    )
    running_integration_time_s = 0
    final_cost_s = 0  # integration time of the completed KG trials
    # Initial Sobol samples (their integration time is not counted toward the budget)
    for _ in range(num_sobol):  # `_` avoids shadowing the campaign index `i`
        t0 = time()
        parameters, trial_index = ax_client.get_next_trial()
        model_runtime_dict = {model_runtime_name: time() - t0}
        results = evaluate(parameters)
        raw_data = {**model_runtime_dict, **results}
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)
    # KGBO
    while running_integration_time_s < time_limit_s:
        t0 = time()
        q_p, q_t = [], []
        # Generate a batch of candidates
        for q in range(batch_size):
            parameters, trial_index = ax_client.get_next_trial()
            q_p.append(parameters)
            q_t.append(trial_index)
        model_runtime_dict = {model_runtime_name: time() - t0}
        for q in range(batch_size):
            pi = q_p[q]
            ti = q_t[q]
            integration_time = calc_integration_time_s(pi["atime"], pi["astep"])
            running_integration_time_s = running_integration_time_s + integration_time
            if running_integration_time_s > time_limit_s:
                # budget exceeded: skip this experiment, abandon the trial, and stop
                ax_client.abandon_trial(trial_index=ti)
                break
            results = evaluate(pi)
            raw_data = {**model_runtime_dict, **results}
            ax_client.complete_trial(trial_index=ti, raw_data=raw_data)
            final_cost_s = running_integration_time_s
    print(f"Time limit reached: {time_limit_s} s")
    print(f"Running integration time {final_cost_s} s")
    campaign_objects.append({"campaign_num": i, "sdl": sdl, "ax_client": ax_client})
[INFO 08-09 18:37:54] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter R. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter G. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter B. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter atime. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:37:54] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='R', parameter_type=INT, range=[0, 89]), RangeParameter(name='G', parameter_type=INT, range=[0, 89]), RangeParameter(name='B', parameter_type=INT, range=[0, 89]), RangeParameter(name='atime', parameter_type=INT, range=[0, 5], fidelity=True, target_value=5), FixedParameter(name='astep', parameter_type=INT, value=5)], parameter_constraints=[]).
[INFO 08-09 18:37:54] ax.service.ax_client: Generated new trial 0 with parameters {'R': 76, 'G': 50, 'B': 25, 'atime': 0, 'astep': 5}.
model_kwargs: {'torch_device': device(type='cuda')}
[INFO 08-09 18:37:57] ax.service.ax_client: Completed trial 0 with data: {'model_runtime_s': (0.008999, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:37:57] ax.service.ax_client: Generated new trial 1 with parameters {'R': 82, 'G': 58, 'B': 18, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:38:01] ax.service.ax_client: Completed trial 1 with data: {'model_runtime_s': (0.015611, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:38:01] ax.service.ax_client: Generated new trial 2 with parameters {'R': 55, 'G': 3, 'B': 87, 'atime': 4, 'astep': 5}.
[INFO 08-09 18:38:04] ax.service.ax_client: Completed trial 2 with data: {'model_runtime_s': (0.006849, None), 'frechet': (15567.0, None)}.
[INFO 08-09 18:38:04] ax.service.ax_client: Generated new trial 3 with parameters {'R': 32, 'G': 11, 'B': 18, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:38:08] ax.service.ax_client: Completed trial 3 with data: {'model_runtime_s': (0.005571, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:38:08] ax.service.ax_client: Generated new trial 4 with parameters {'R': 10, 'G': 35, 'B': 0, 'atime': 3, 'astep': 5}.
[INFO 08-09 18:38:11] ax.service.ax_client: Completed trial 4 with data: {'model_runtime_s': (0.013036, None), 'frechet': (15567.388638, None)}.
[INFO 08-09 18:38:11] ax.service.ax_client: Generated new trial 5 with parameters {'R': 42, 'G': 71, 'B': 88, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:38:15] ax.service.ax_client: Completed trial 5 with data: {'model_runtime_s': (0.0115, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:39:09] ax.service.ax_client: Generated new trial 6 with parameters {'R': 79, 'G': 0, 'B': 89, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:39:12] ax.service.ax_client: Completed trial 6 with data: {'model_runtime_s': (53.845453, None), 'frechet': (15568.0, None)}.
[INFO 08-09 18:40:16] ax.service.ax_client: Generated new trial 7 with parameters {'R': 36, 'G': 0, 'B': 89, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:40:19] ax.service.ax_client: Completed trial 7 with data: {'model_runtime_s': (63.513963, None), 'frechet': (15565.0, None)}.
[INFO 08-09 18:41:28] ax.service.ax_client: Generated new trial 8 with parameters {'R': 16, 'G': 0, 'B': 89, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:41:32] ax.service.ax_client: Completed trial 8 with data: {'model_runtime_s': (68.218229, None), 'frechet': (15567.157383, None)}.
[INFO 08-09 18:42:42] ax.service.ax_client: Generated new trial 9 with parameters {'R': 34, 'G': 0, 'B': 69, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:42:45] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 08-09 18:42:45] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter R. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:42:45] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter G. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:42:45] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter B. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:42:45] ax.service.utils.instantiation: Inferred value type of ParameterType.INT for parameter atime. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 08-09 18:42:45] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='R', parameter_type=INT, range=[0, 89]), RangeParameter(name='G', parameter_type=INT, range=[0, 89]), RangeParameter(name='B', parameter_type=INT, range=[0, 89]), RangeParameter(name='atime', parameter_type=INT, range=[0, 5], fidelity=True, target_value=5), FixedParameter(name='astep', parameter_type=INT, value=5)], parameter_constraints=[]).
[INFO 08-09 18:42:45] ax.service.ax_client: Generated new trial 0 with parameters {'R': 22, 'G': 69, 'B': 63, 'atime': 0, 'astep': 5}.
Time limit reached: 0.00010008 s
Running integration time 5.004e-05 s
[INFO 08-09 18:42:48] ax.service.ax_client: Completed trial 0 with data: {'model_runtime_s': (0.0, None), 'frechet': (9536.179791, None)}.
[INFO 08-09 18:42:48] ax.service.ax_client: Generated new trial 1 with parameters {'R': 1, 'G': 59, 'B': 25, 'atime': 3, 'astep': 5}.
[INFO 08-09 18:42:52] ax.service.ax_client: Completed trial 1 with data: {'model_runtime_s': (0.015429, None), 'frechet': (9537.335529, None)}.
[INFO 08-09 18:42:52] ax.service.ax_client: Generated new trial 2 with parameters {'R': 79, 'G': 47, 'B': 76, 'atime': 2, 'astep': 5}.
[INFO 08-09 18:42:55] ax.service.ax_client: Completed trial 2 with data: {'model_runtime_s': (0.00596, None), 'frechet': (9530.180533, None)}.
[INFO 08-09 18:42:55] ax.service.ax_client: Generated new trial 3 with parameters {'R': 26, 'G': 5, 'B': 84, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:42:59] ax.service.ax_client: Completed trial 3 with data: {'model_runtime_s': (0.013637, None), 'frechet': (9533.180162, None)}.
[INFO 08-09 18:42:59] ax.service.ax_client: Generated new trial 4 with parameters {'R': 24, 'G': 26, 'B': 46, 'atime': 2, 'astep': 5}.
[INFO 08-09 18:43:02] ax.service.ax_client: Completed trial 4 with data: {'model_runtime_s': (0.008964, None), 'frechet': (9533.047204, None)}.
[INFO 08-09 18:43:02] ax.service.ax_client: Generated new trial 5 with parameters {'R': 14, 'G': 5, 'B': 35, 'atime': 5, 'astep': 5}.
[INFO 08-09 18:43:05] ax.service.ax_client: Completed trial 5 with data: {'model_runtime_s': (0.008964, None), 'frechet': (9537.335529, None)}.
[INFO 08-09 18:44:10] ax.service.ax_client: Generated new trial 6 with parameters {'R': 74, 'G': 20, 'B': 83, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:44:14] ax.service.ax_client: Completed trial 6 with data: {'model_runtime_s': (64.559284, None), 'frechet': (9535.0, None)}.
[INFO 08-09 18:45:14] ax.service.ax_client: Generated new trial 7 with parameters {'R': 89, 'G': 66, 'B': 88, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:45:17] ax.service.ax_client: Completed trial 7 with data: {'model_runtime_s': (60.053519, None), 'frechet': (9534.047199, None)}.
[INFO 08-09 18:46:21] ax.service.ax_client: Generated new trial 8 with parameters {'R': 86, 'G': 48, 'B': 50, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:46:24] ax.service.ax_client: Completed trial 8 with data: {'model_runtime_s': (63.737873, None), 'frechet': (9534.047199, None)}.
[INFO 08-09 18:47:13] ax.service.ax_client: Generated new trial 9 with parameters {'R': 56, 'G': 50, 'B': 82, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:47:17] ax.service.ax_client: Completed trial 9 with data: {'model_runtime_s': (49.075867, None), 'frechet': (9538.0, None)}.
[INFO 08-09 18:48:25] ax.service.ax_client: Generated new trial 10 with parameters {'R': 89, 'G': 34, 'B': 89, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:48:29] ax.service.ax_client: Completed trial 10 with data: {'model_runtime_s': (67.872802, None), 'frechet': (9538.0, None)}.
[INFO 08-09 18:49:26] ax.service.ax_client: Generated new trial 11 with parameters {'R': 89, 'G': 81, 'B': 59, 'atime': 0, 'astep': 5}.
[INFO 08-09 18:49:29] ax.service.ax_client: Completed trial 11 with data: {'model_runtime_s': (57.316357, None), 'frechet': (9536.179791, None)}.
[INFO 08-09 18:50:28] ax.service.ax_client: Generated new trial 12 with parameters {'R': 74, 'G': 27, 'B': 39, 'atime': 0, 'astep': 5}.
Time limit reached: 0.00010008 s
Running integration time 0.00010007999999999999 s
[20]:
running_integration_time_s
[20]:
0.00011675999999999999
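The value above exceeds `time_limit_s` because `running_integration_time_s` also includes the trial that overshot the budget, whereas the printed "Running integration time" backs that trial out. The bookkeeping can be replayed without Ax or hardware. This is a minimal sketch with a hypothetical `replay_budget` helper, fed the atime values of the KG trials from the first campaign's log above (trials 6 to 9, dummy settings with astep fixed at 5):

```python
def calc_integration_time_s(atime, astep):
    # same formula as the notebook's helper
    return ((atime + 1) * (astep + 1) * 2.78) / 1e6


def replay_budget(atimes, astep, time_limit_s):
    """Mimic the KG loop's budget check for a fixed sequence of atime values."""
    running_s = 0.0  # includes the trial that overshoots the budget
    final_cost_s = 0.0  # only trials that fit within the budget
    n_completed = 0
    for atime in atimes:
        running_s += calc_integration_time_s(atime, astep)
        if running_s > time_limit_s:
            break  # budget exceeded: this trial is not completed
        final_cost_s = running_s
        n_completed += 1
    return n_completed, final_cost_s, running_s


time_limit_s = calc_integration_time_s(5, 5)  # dummy settings, time_limit_multiplier = 1
print(replay_budget([0, 0, 0, 5], astep=5, time_limit_s=time_limit_s))
```

Three low-fidelity trials fit, and the fourth (atime = 5, a full budget on its own) pushes the running total past the limit, matching the `5.004e-05 s` printed for the first campaign.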
[6]:
[obj["ax_client"] for obj in campaign_objects]
[6]:
[AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
AxClient(experiment=Experiment(sdl_demo_mf_experiment)),
AxClient(experiment=Experiment(sdl_demo_mf_experiment))]
Code Graveyard
[ ]:
# params.append(
#     dict(
#         name="astep",
#         type="range",
#         is_fidelity=True,
#         bounds=astep_bnd,
#         target_value=astep_bnd[1],
#     )
# )
[ ]:
# from botorch.test_functions.multi_fidelity import AugmentedHartmann
# import torch
[ ]:
# [
#     "utc_timestamp",
#     "ch470",
#     "ch550",
#     "ch670",
#     "ch410",
#     "background",  # needs to be collapsed
#     "ch620",
#     "sd_card_ready",
#     "ch510",
#     "warning",
#     "ch583",
#     "device_nickname",
#     "ch440",
#     "onboard_temperature_K",
#     "encrypted_device_id_truncated",
#     "mae",
#     "rmse",
#     "frechet",
# ]