
Optimization API

Schedule Classes

spectracles.optimise.opt_schedule.OptimiserSchedule

Multi-phase optimization schedule for ShareModule models.

Runs a sequence of optimization phases, each with its own optimizer, learning rate, and parameter configuration. Tracks model history across phases.

Attributes:

phases: List of Phase objects defining the schedule.

model_history: List of models, one for initial state plus one after each completed phase.

Example

    >>> schedule = OptimiserSchedule(
    ...     model, loss_fn,
    ...     phase_configs=[
    ...         PhaseConfig(n_steps=100, optimiser=optax.adam(0.1)),
    ...         PhaseConfig(n_steps=50, optimiser=optax.adam(0.01)),
    ...     ],
    ... )
    >>> schedule.run_all(x=data_x, y=data_y)
    >>> final_model = schedule.model_history[-1]

See Also

ManagedOptimiserSchedule: Version with state tracking, skip, and reset.

build_schedule: Declarative API for building schedules.

loss_history property

Get the total loss history from all phases.

loss_histories property

Get the loss history of each phase, for all phases.
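The two properties above are easiest to tell apart with a small example. The sketch below assumes that loss_histories holds one sequence of losses per phase and that loss_history is their concatenation in phase order; that relationship is inferred from the property names and is not confirmed by this page. The numbers are made up for illustration.

```python
from itertools import chain

# Hypothetical per-phase loss records, standing in for loss_histories.
loss_histories = [
    [4.0, 2.0, 1.0],  # phase 0
    [0.9, 0.8],       # phase 1
]

# Assumed relationship: the total history is the per-phase histories
# concatenated in phase order.
loss_history = list(chain.from_iterable(loss_histories))

# Index in the flat history at which each phase starts, useful for
# marking phase boundaries on a loss plot.
phase_starts = [0]
for phase_losses in loss_histories[:-1]:
    phase_starts.append(phase_starts[-1] + len(phase_losses))
```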

run_all(*loss_args, **loss_kwargs)

Run all phases in the schedule.

run_phase(phase, *loss_args, **loss_kwargs)

Run a single phase in the schedule.
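To make the phase-by-phase behaviour concrete, here is a minimal pure-Python sketch of the idea behind a multi-phase schedule. It is not the spectracles implementation: the "model" is a single float, each phase is a hypothetical (n_steps, learning_rate) pair, and plain gradient descent stands in for an optax optimizer. It only illustrates why model_history ends up with one entry more than the number of phases.

```python
def loss(w):
    """Toy quadratic loss with its minimum at w = 3."""
    return (w - 3.0) ** 2


def grad(w):
    """Analytic gradient of the toy loss."""
    return 2.0 * (w - 3.0)


# Hypothetical phases: (number of steps, learning rate).
# A coarse phase is followed by a finer one, as in the example above.
phases = [(100, 0.1), (50, 0.01)]

w = 0.0
model_history = [w]  # entry 0 is the initial state

for n_steps, learning_rate in phases:
    for _ in range(n_steps):
        w = w - learning_rate * grad(w)  # one gradient-descent step
    model_history.append(w)  # snapshot after the phase completes

final_model = model_history[-1]
```

Keeping a snapshot per phase makes it possible to compare the model before and after any phase, at the cost of holding one copy of the model per phase in memory.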

spectracles.optimise.opt_schedule.ManagedOptimiserSchedule

Multi-phase optimization schedule with state tracking and control.

Extends OptimiserSchedule with phase state management, allowing you to:

- Run phases one at a time with run_next_phase()
- Skip phases with skip_phase()
- Reset and re-run with reset() or reset_from_phase()
- Inspect progress with get_phase_status(), is_complete(), etc.

Phases must be run in order. The schedule tracks which phases are pending, running, completed, or skipped.
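The in-order state tracking described above can be sketched as a small state machine. PhaseTracker below is a hypothetical stand-in written for this illustration, not the library class: it performs no optimization and only records state transitions. It also assumes that skipping is allowed only for the next phase in order, which this page does not state.

```python
from enum import Enum


class PhaseState(Enum):
    """Stand-in enum using the four documented state names."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    SKIPPED = "skipped"


class PhaseTracker:
    """Hypothetical helper that tracks phase states only."""

    def __init__(self, n_phases):
        self.phase_states = [PhaseState.PENDING] * n_phases
        self.current_phase_index = 0  # index of the next phase to run

    def run_next_phase(self):
        """Run the next pending phase; return False if none remain."""
        idx = self.current_phase_index
        if idx >= len(self.phase_states):
            return False
        self.phase_states[idx] = PhaseState.RUNNING
        # ... a real schedule would run its optimization loop here ...
        self.phase_states[idx] = PhaseState.COMPLETED
        self.current_phase_index += 1
        return True

    def skip_phase(self, phase_idx):
        """Mark the next phase as skipped (assumed: in order only)."""
        if phase_idx != self.current_phase_index or phase_idx >= len(self.phase_states):
            raise ValueError("only the next pending phase can be skipped")
        self.phase_states[phase_idx] = PhaseState.SKIPPED
        self.current_phase_index += 1


tracker = PhaseTracker(3)
tracker.run_next_phase()   # phase 0 runs
tracker.skip_phase(1)      # phase 1 is skipped
tracker.run_next_phase()   # phase 2 runs
states = [s.name for s in tracker.phase_states]
more_to_run = tracker.run_next_phase()  # nothing left to run
```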

Attributes:

phases: List of Phase objects defining the schedule.

model_history: List of models, one for initial state plus one after each completed phase.

phase_states: List of PhaseState values tracking each phase's status.

current_phase_index: Index of the next phase to run.

initial_model: The original model (kept for reset functionality).

Example

    >>> schedule = ManagedOptimiserSchedule(model, loss_fn, phase_configs)

    >>> # Run phases one at a time
    >>> schedule.run_next_phase(x=x, y=y)
    >>> print(f"Phase 0 loss: {schedule.loss_history[-1]:.4f}")

    >>> # Check status
    >>> print(schedule.get_phase_status())
    {'current_phase': 1, 'total_phases': 3, 'completed_phases': 1, ...}

    >>> # Skip a phase if needed
    >>> schedule.skip_phase(1)

    >>> # Run remaining phases
    >>> schedule.run_all(x=x, y=y)

    >>> # Reset and try again
    >>> schedule.reset()
    >>> schedule.run_all(x=x, y=y)

See Also

OptimiserSchedule: Simpler version without state tracking.

build_schedule: Declarative API with managed=True option.

loss_history property

Get the total loss history from completed phases.

loss_histories property

Get the loss history from completed phases only.

run_all(*loss_args, **loss_kwargs)

Run all remaining phases in the schedule.

run_next_phase(*loss_args, **loss_kwargs)

Run the next pending phase. Returns True if a phase was run, False if all complete.

run_phase_by_index(phase_idx, *loss_args, **loss_kwargs)

Run a specific phase by index, with validation.

run_phases(phase_indices, *loss_args, **loss_kwargs)

Run multiple phases by index.

skip_phase(phase_idx)

Skip a phase without running it. The phase is marked as SKIPPED and the schedule moves on to the next phase.

reset()

Reset the schedule to initial state with fresh optimizer frames.

reset_from_phase(phase_idx)

Reset schedule from a specific phase onwards with fresh optimizer frames.

get_phase_status()

Get detailed status of all phases.

is_complete()

Check if all phases are completed.

get_next_phase_index()

Get the index of the next phase to run, or None if complete.

get_completed_phases()

Get indices of all completed phases.

get_pending_phases()

Get indices of all pending phases.
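The inspection methods above can all be thought of as views over phase_states. The sketch below shows one way such helpers could be derived from a list of states. The functions are hypothetical stand-ins, not the library's code, and they make two assumptions: a schedule counts as complete once no phase is PENDING or RUNNING, and skipped phases are not listed among the completed ones. The status dictionary contains only the three keys visible in the example output earlier in this section.

```python
from enum import Enum, auto


class PhaseState(Enum):
    """Stand-in enum using the four documented state names."""
    PENDING = auto()
    RUNNING = auto()
    COMPLETED = auto()
    SKIPPED = auto()


def get_completed_phases(states):
    return [i for i, s in enumerate(states) if s is PhaseState.COMPLETED]


def get_pending_phases(states):
    return [i for i, s in enumerate(states) if s is PhaseState.PENDING]


def get_next_phase_index(states):
    """Index of the first pending phase, or None if there is none."""
    pending = get_pending_phases(states)
    return pending[0] if pending else None


def is_complete(states):
    """Assumed meaning: nothing is still pending or running."""
    unfinished = (PhaseState.PENDING, PhaseState.RUNNING)
    return not any(s in unfinished for s in states)


def get_phase_status(states):
    """Summary with the three keys shown in the example output."""
    return {
        # Phases run in order, so this counts the leading phases
        # that are no longer pending.
        "current_phase": len(states) - len(get_pending_phases(states)),
        "total_phases": len(states),
        "completed_phases": len(get_completed_phases(states)),
    }


in_progress = [PhaseState.COMPLETED, PhaseState.PENDING, PhaseState.PENDING]
finished = [PhaseState.COMPLETED, PhaseState.SKIPPED]
status = get_phase_status(in_progress)
```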

Configuration

spectracles.optimise.opt_schedule.PhaseConfig dataclass

Configuration for a single optimization phase.

Defines the number of steps, optimizer, and any parameter updates to apply at the start of the phase.

Attributes:

n_steps (int): Number of optimization steps to run in this phase.

optimiser (GradientTransformation): An optax optimizer (e.g., optax.adam(0.01)).

Δloss_criterion (float): Convergence criterion for loss change (default 1e2).

fix_status_updates (dict[str, bool]): Dict mapping parameter paths to their fixed status. True means fixed (not optimized), False means free.

param_val_updates (dict[str, Array]): Dict mapping parameter paths to new values to set at the start of this phase.

Example

    >>> config = PhaseConfig(
    ...     n_steps=100,
    ...     optimiser=optax.adam(0.01),
    ...     fix_status_updates={"gp.kernel.lengthscale": True},
    ...     param_val_updates={"gp.coefficients": jnp.zeros(10)},
    ... )
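The dotted parameter paths used by fix_status_updates and param_val_updates can be illustrated with plain nested dictionaries. The sketch below is an assumption about how such paths might resolve, not the library's mechanism: set_by_path is a hypothetical helper, the model is represented by two nested dicts instead of a ShareModule, and plain lists stand in for arrays.

```python
def set_by_path(tree, path, value):
    """Set tree[a][b][c] = value for the path "a.b.c" (hypothetical helper)."""
    *parents, leaf = path.split(".")
    node = tree
    for key in parents:
        node = node[key]
    if leaf not in node:
        # Refuse to create new parameters from a mistyped path.
        raise KeyError(path)
    node[leaf] = value


# Hypothetical model layout: one tree of fixed flags, one tree of values.
fixed = {"gp": {"kernel": {"lengthscale": False}, "coefficients": False}}
values = {"gp": {"kernel": {"lengthscale": 1.5}, "coefficients": [0.3, -0.2]}}

# Updates in the same shape as the PhaseConfig fields.
fix_status_updates = {"gp.kernel.lengthscale": True}
param_val_updates = {"gp.coefficients": [0.0, 0.0]}

# Apply both kinds of update at the start of the phase.
for path, is_fixed in fix_status_updates.items():
    set_by_path(fixed, path, is_fixed)
for path, new_value in param_val_updates.items():
    set_by_path(values, path, new_value)
```

Raising on an unknown path, rather than silently adding a key, is a design choice of this sketch: it surfaces typos in parameter paths before any optimization steps are spent.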

spectracles.optimise.opt_schedule.PhaseState

Bases: Enum

State of a phase in ManagedOptimiserSchedule.

Attributes:

PENDING: Phase has not been run yet.

RUNNING: Phase is currently executing.

COMPLETED: Phase finished successfully.

SKIPPED: Phase was skipped via skip_phase().