Optimization API
Schedule Classes
spectracles.optimise.opt_schedule.OptimiserSchedule
Multi-phase optimization schedule for ShareModule models.
Runs a sequence of optimization phases, each with its own optimizer, learning rate, and parameter configuration. Tracks model history across phases.
Attributes:
| Name | Type | Description |
|---|---|---|
| phases | | List of Phase objects defining the schedule. |
| model_history | | List of models, one for initial state plus one after each completed phase. |
Example

    schedule = OptimiserSchedule(
        model, loss_fn,
        phase_configs=[
            PhaseConfig(n_steps=100, optimiser=optax.adam(0.1)),
            PhaseConfig(n_steps=50, optimiser=optax.adam(0.01)),
        ],
    )
    schedule.run_all(x=data_x, y=data_y)
    final_model = schedule.model_history[-1]

See Also
ManagedOptimiserSchedule: Version with state tracking, skip, and reset.
build_schedule: Declarative API for building schedules.
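To make the multi-phase idea concrete, here is a minimal sketch in plain Python that does not use spectracles or optax. `ToyPhase` and `run_toy_schedule` are hypothetical names invented for illustration, and plain gradient descent stands in for a real optimizer. It only shows the bookkeeping described above: phases run in sequence, each with its own step count and learning rate, and the history holds the initial state plus one entry per completed phase.

```python
from dataclasses import dataclass


@dataclass
class ToyPhase:
    """Hypothetical stand-in for a phase: a step count and a learning rate."""
    n_steps: int
    learning_rate: float


def run_toy_schedule(w0, phases, grad_fn):
    """Run each phase in order and record the model after every phase."""
    model_history = [w0]  # entry 0 is the initial state
    w = w0
    for phase in phases:
        for _ in range(phase.n_steps):
            # Plain gradient descent stands in for a real optimizer update.
            w = w - phase.learning_rate * grad_fn(w)
        model_history.append(w)  # one entry per completed phase
    return model_history


# Minimise (w - 3)^2, whose gradient is 2 * (w - 3).
history = run_toy_schedule(
    w0=0.0,
    phases=[
        ToyPhase(n_steps=100, learning_rate=0.1),
        ToyPhase(n_steps=50, learning_rate=0.01),
    ],
    grad_fn=lambda w: 2.0 * (w - 3.0),
)
final_w = history[-1]
```

With two phases the history has three entries, mirroring the description of model_history above.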
spectracles.optimise.opt_schedule.ManagedOptimiserSchedule
Multi-phase optimization schedule with state tracking and control.
Extends OptimiserSchedule with phase state management, allowing you to:
- Run phases one at a time with run_next_phase()
- Skip phases with skip_phase()
- Reset and re-run with reset() or reset_from_phase()
- Inspect progress with get_phase_status(), is_complete(), etc.
Phases must be run in order. The schedule tracks which phases are pending, running, completed, or skipped.
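The state tracking described above can be pictured as a small state machine. The sketch below is a toy model written for illustration, not the library's implementation: `ToyState` and `ToyManagedSchedule` are hypothetical, and the rule that a phase can only be skipped when it is the next one in line is an assumption drawn from "phases must be run in order".

```python
from enum import Enum


class ToyState(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    SKIPPED = "skipped"


class ToyManagedSchedule:
    """Hypothetical tracker that mimics in-order phase bookkeeping."""

    def __init__(self, n_phases):
        self.phase_states = [ToyState.PENDING] * n_phases
        self.current_phase_index = 0

    def run_next_phase(self):
        """Run the next phase; return False when nothing is left to run."""
        if self.current_phase_index >= len(self.phase_states):
            return False
        idx = self.current_phase_index
        self.phase_states[idx] = ToyState.RUNNING
        # The optimization work for this phase would happen here.
        self.phase_states[idx] = ToyState.COMPLETED
        self.current_phase_index += 1
        return True

    def skip_phase(self, idx):
        """Mark the next phase as skipped (assumed: only in order)."""
        if idx != self.current_phase_index:
            raise ValueError("phases must be handled in order")
        self.phase_states[idx] = ToyState.SKIPPED
        self.current_phase_index += 1

    def reset(self):
        """Return every phase to PENDING and start again from phase 0."""
        self.phase_states = [ToyState.PENDING] * len(self.phase_states)
        self.current_phase_index = 0


toy = ToyManagedSchedule(n_phases=3)
toy.run_next_phase()               # phase 0 runs
toy.skip_phase(1)                  # phase 1 is skipped
toy.run_next_phase()               # phase 2 runs
ran_again = toy.run_next_phase()   # nothing left to run
states_after_run = list(toy.phase_states)
```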
Attributes:
| Name | Type | Description |
|---|---|---|
| phases | | List of Phase objects defining the schedule. |
| model_history | | List of models, one for initial state plus one after each completed phase. |
| phase_states | | List of PhaseState values tracking each phase's status. |
| current_phase_index | | Index of the next phase to run. |
| initial_model | | The original model (kept for reset functionality). |
Example

    schedule = ManagedOptimiserSchedule(model, loss_fn, phase_configs)

    # Run phases one at a time
    schedule.run_next_phase(x=x, y=y)
    print(f"Phase 0 loss: {schedule.loss_history[-1]:.4f}")

    # Check status
    print(schedule.get_phase_status())
    # {'current_phase': 1, 'total_phases': 3, 'completed_phases': 1, ...}

    # Skip a phase if needed
    schedule.skip_phase(1)

    # Run remaining phases
    schedule.run_all(x=x, y=y)

    # Reset and try again
    schedule.reset()
    schedule.run_all(x=x, y=y)

See Also
OptimiserSchedule: Simpler version without state tracking.
build_schedule: Declarative API with managed=True option.
loss_history
property
Get the total loss history from completed phases.
loss_histories
property
Get the loss history from completed phases only.
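The two properties are described only briefly here. One plausible reading, stated as an assumption and not confirmed by this page, is that loss_histories holds one list of losses per completed phase while loss_history is their concatenation. The sketch below illustrates that reading with made-up numbers and plain Python lists.

```python
from itertools import chain

# Assumed shape: one list of per-step losses for each completed phase.
# The values are invented for illustration.
loss_histories = [
    [4.0, 2.0, 1.0],  # phase 0
    [0.9, 0.8],       # phase 1
]

# Assumed relationship: the total history is the per-phase lists joined in order.
loss_history = list(chain.from_iterable(loss_histories))

latest_loss = loss_history[-1]
```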
run_all(*loss_args, **loss_kwargs)
Run all remaining phases in the schedule.
run_next_phase(*loss_args, **loss_kwargs)
Run the next pending phase. Returns True if a phase was run, False if all complete.
run_phase_by_index(phase_idx, *loss_args, **loss_kwargs)
Run a specific phase by index, with validation.
run_phases(phase_indices, *loss_args, **loss_kwargs)
Run multiple phases by index.
skip_phase(phase_idx)
Skip a phase (mark it as skipped without running it).
reset()
Reset the schedule to initial state with fresh optimizer frames.
reset_from_phase(phase_idx)
Reset schedule from a specific phase onwards with fresh optimizer frames.
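As a rough illustration of what a partial reset involves, the sketch below works on plain lists. `reset_from` is a hypothetical helper, and the behaviour it shows is an assumption: phases from the given index onwards return to pending, and the model history is cut back to the entries produced before that phase. The library may handle details such as skipped phases or optimizer state differently.

```python
def reset_from(states, model_history, phase_idx):
    """Return new (states, history) with phases >= phase_idx set to pending.

    Assumes model_history[0] is the initial model and entry k + 1 is the
    model after phase k, so phase_idx + 1 entries survive the reset.
    """
    if not 0 <= phase_idx < len(states):
        raise IndexError(f"phase index {phase_idx} out of range")
    new_states = states[:phase_idx] + ["pending"] * (len(states) - phase_idx)
    new_history = model_history[: phase_idx + 1]
    return new_states, new_history


states = ["completed", "completed", "completed"]
model_history = ["m0", "m1", "m2", "m3"]  # placeholder model labels

new_states, new_history = reset_from(states, model_history, phase_idx=1)
```

Resetting from phase 0 in this sketch leaves only the initial model in the history, which matches a full reset.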
get_phase_status()
Get detailed status of all phases.
is_complete()
Check if all phases are completed.
get_next_phase_index()
Get the index of the next phase to run, or None if complete.
get_completed_phases()
Get indices of all completed phases.
get_pending_phases()
Get indices of all pending phases.
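The inspection methods above can all be derived from a list of per-phase states. The following sketch uses plain strings in place of PhaseState and hypothetical free functions in place of the methods. Treating a skipped phase as finished for the purpose of is_complete is an assumption, not something this page states.

```python
def completed_phases(states):
    """Indices of phases marked completed."""
    return [i for i, s in enumerate(states) if s == "completed"]


def pending_phases(states):
    """Indices of phases that have not been run or skipped yet."""
    return [i for i, s in enumerate(states) if s == "pending"]


def next_phase_index(states):
    """Index of the first pending phase, or None if there is none."""
    pending = pending_phases(states)
    return pending[0] if pending else None


def is_complete(states):
    """True when every phase is completed or (by assumption) skipped."""
    return all(s in ("completed", "skipped") for s in states)


in_progress = ["completed", "skipped", "pending"]
finished = ["completed", "skipped", "completed"]
```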
Configuration
spectracles.optimise.opt_schedule.PhaseConfig
dataclass
Configuration for a single optimization phase.
Defines the number of steps, optimizer, and any parameter updates to apply at the start of the phase.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_steps | int | Number of optimization steps to run in this phase. |
| optimiser | GradientTransformation | An optax optimizer (e.g., |
| Δloss_criterion | float | Convergence criterion for loss change (default 1e2). |
| fix_status_updates | dict[str, bool] | Dict mapping parameter paths to their fixed status. True means fixed (not optimized), False means free. |
| param_val_updates | dict[str, Array] | Dict mapping parameter paths to new values to set at the start of this phase. |
Example

    config = PhaseConfig(
        n_steps=100,
        optimiser=optax.adam(0.01),
        fix_status_updates={"gp.kernel.lengthscale": True},
        param_val_updates={"gp.coefficients": jnp.zeros(10)},
    )

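To show what the two update dicts mean in practice, the sketch below applies them to a flat dictionary of parameters keyed by path. `apply_phase_updates` is a hypothetical helper and plain lists stand in for arrays. How the library resolves paths on a real model, and whether it rejects unknown paths as this sketch does, are assumptions.

```python
def apply_phase_updates(values, fixed, fix_status_updates, param_val_updates):
    """Return updated copies of the value and fixed-status dicts.

    values: parameter path -> current value
    fixed:  parameter path -> True if held fixed, False if optimized
    """
    unknown = (set(fix_status_updates) | set(param_val_updates)) - set(values)
    if unknown:
        raise KeyError(f"unknown parameter paths: {sorted(unknown)}")
    new_values = {**values, **param_val_updates}  # overwrite selected values
    new_fixed = {**fixed, **fix_status_updates}   # overwrite selected flags
    return new_values, new_fixed


values = {"gp.kernel.lengthscale": 1.0, "gp.coefficients": [0.3, -0.2]}
fixed = {"gp.kernel.lengthscale": False, "gp.coefficients": False}

new_values, new_fixed = apply_phase_updates(
    values,
    fixed,
    fix_status_updates={"gp.kernel.lengthscale": True},
    param_val_updates={"gp.coefficients": [0.0, 0.0]},
)
```

After the call, the lengthscale is held fixed for the phase and the coefficients start from zero, while the input dicts are left untouched.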
spectracles.optimise.opt_schedule.PhaseState
Bases: Enum
State of a phase in ManagedOptimiserSchedule.
Attributes:
| Name | Type | Description |
|---|---|---|
| PENDING | | Phase has not been run yet. |
| RUNNING | | Phase is currently executing. |
| COMPLETED | | Phase finished successfully. |
| SKIPPED | | Phase was skipped via skip_phase(). |