torchebm.core ¶
Core functionality for energy-based models, including energy functions, base sampler class, and training utilities.
AckleyModel ¶
Bases: BaseModel
Energy-based model for the Ackley function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | float | The amplitude coefficient of the exponential term. | 20.0
b | float | The decay-rate coefficient inside the exponential term. | 0.2
c | float | The frequency coefficient of the cosine term. | 2 * pi
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the Ackley energy.
Source code in torchebm/core/base_model.py
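For orientation, the energy this model evaluates is the standard Ackley function. A minimal standalone sketch (plain Python with a hypothetical `ackley_energy` helper, independent of the torchebm implementation):

```python
import math

def ackley_energy(x, a=20.0, b=0.2, c=2 * math.pi):
    """Ackley function for a list of coordinates x (global minimum 0 at the origin)."""
    d = len(x)
    mean_sq = sum(xi * xi for xi in x) / d
    mean_cos = sum(math.cos(c * xi) for xi in x) / d
    return -a * math.exp(-b * math.sqrt(mean_sq)) - math.exp(mean_cos) + a + math.e

print(ackley_energy([0.0, 0.0]))  # energy at the global minimum is 0
print(ackley_energy([1.0, 1.0]))  # positive away from the optimum
```

The highly multimodal landscape of this function makes it a common stress test for samplers.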
BaseContrastiveDivergence ¶
Bases: BaseLoss
Abstract base class for Contrastive Divergence (CD) based loss functions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | BaseModel | The energy-based model to be trained. | required |
sampler | BaseSampler | The MCMC sampler for generating negative samples. | required |
k_steps | int | The number of MCMC steps to perform for each update. | 1 |
persistent | bool | If True, uses Persistent Contrastive Divergence (PCD), maintaining a replay buffer of chains across updates. | False
buffer_size | int | The size of the replay buffer for PCD. | 100 |
new_sample_ratio | float | The ratio of new random samples to introduce into the MCMC chain. | 0.0 |
init_steps | int | The number of MCMC steps to run when initializing new chain elements. | 0 |
dtype | dtype | Data type for computations. | float32 |
device | Optional[Union[str, device]] | Device for computations. | None |
use_mixed_precision | bool | Whether to use mixed precision training. | False |
clip_value | Optional[float] | Optional value to clamp the loss. | None |
Source code in torchebm/core/base_loss.py
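At its core, the contrastive divergence objective lowers the energy of data and raises the energy of sampler-generated negatives. A toy sketch of that objective on plain Python scalars (illustrative only, not torchebm's implementation; `energy` stands in for a `BaseModel`):

```python
def energy(x):
    # Toy quadratic energy E(x) = x^2 with its minimum at 0.
    return x * x

def cd_loss(data, negatives):
    # mean E(positive) - mean E(negative): minimized when data
    # sits at lower energy than the negatives.
    pos = sum(energy(x) for x in data) / len(data)
    neg = sum(energy(x) for x in negatives) / len(negatives)
    return pos - neg

# Data near the energy minimum, negatives far from it -> negative loss.
print(cd_loss([0.1, -0.1], [2.0, -2.0]))
```

In practice the negatives come from `k_steps` of MCMC starting either at the data (standard CD) or at the replay buffer (PCD).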
__repr__() ¶
__str__() ¶
compute_loss(x, pred_x, *args, **kwargs) abstractmethod ¶
Computes the contrastive divergence loss from positive and negative samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Real data samples (positive samples). | required |
pred_x | Tensor | Generated negative samples. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The contrastive divergence loss. |
Source code in torchebm/core/base_loss.py
forward(x, *args, **kwargs) abstractmethod ¶
Computes the CD loss given real data samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Real data samples (positive samples). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple[torch.Tensor, torch.Tensor]: - The contrastive divergence loss. - The generated negative samples. |
Source code in torchebm/core/base_loss.py
get_negative_samples(x, batch_size, data_shape) ¶
Gets negative samples using the replay buffer strategy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | | (Unused) The input data tensor. | required
batch_size | int | The number of samples to generate. | required |
data_shape | Tuple[int, ...] | The shape of the data samples (excluding batch size). | required |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: Negative samples. |
Source code in torchebm/core/base_loss.py
get_start_points(x) ¶
Gets the starting points for the MCMC sampler.
For standard CD, this is the input data. For PCD, it's samples from the replay buffer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | The input data batch. | required |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The tensor of starting points for the sampler. |
Source code in torchebm/core/base_loss.py
initialize_buffer(data_shape_no_batch, buffer_chunk_size=1024, init_noise_scale=0.01) ¶
Initializes the replay buffer with random noise for PCD.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data_shape_no_batch | Tuple[int, ...] | The shape of the data excluding the batch dimension. | required |
buffer_chunk_size | int | The size of chunks to process during initialization. | 1024 |
init_noise_scale | float | The scale of the initial noise. | 0.01 |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The initialized replay buffer. |
Source code in torchebm/core/base_loss.py
update_buffer(samples) ¶
Updates the replay buffer with new samples using a FIFO strategy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
samples | Tensor | New samples to add to the buffer. | required |
Source code in torchebm/core/base_loss.py
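The FIFO strategy above can be sketched in a few lines (plain Python lists standing in for tensors; torchebm operates on torch tensors):

```python
def update_buffer(buffer, samples, buffer_size):
    # Append the new samples, then keep only the most recent
    # `buffer_size` entries, so the oldest fall out first.
    return (buffer + samples)[-buffer_size:]

buf = [1, 2, 3]
buf = update_buffer(buf, [4, 5], buffer_size=4)
print(buf)  # → [2, 3, 4, 5]
```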
BaseIntegrator ¶
Bases: DeviceMixin, Module, ABC
Abstract integrator that advances a sampler state according to dynamics.
The integrator operates on a generic state dict to remain reusable across samplers (e.g., Langevin uses only position x, HMC uses position x and momentum p).
Methods follow PyTorch conventions and respect device/dtype from DeviceMixin.
Source code in torchebm/core/base_integrator.py
integrate(state, step_size, n_steps, *args, **kwargs) abstractmethod ¶
Advance the dynamical state by n_steps integrator applications.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping containing required tensors (e.g., {'x': ..., 'p': ...}). | required |
step_size | Tensor | Step size for the integration. | required |
n_steps | int | The number of integration steps to perform. | required |
*args | | Additional positional arguments specific to the integrator. | ()
**kwargs | | Additional keyword arguments specific to the integrator. | {}
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict with the same keys as the input |
Source code in torchebm/core/base_integrator.py
step(state, step_size, *args, **kwargs) abstractmethod ¶
Advance the dynamical state by one integrator application.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping containing required tensors (e.g., {'x': ..., 'p': ...}). | required |
step_size | Tensor | Step size for the integration. | required |
*args | | Additional positional arguments specific to the integrator. | ()
**kwargs | | Additional keyword arguments specific to the integrator. | {}
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict with the same keys as the input |
Source code in torchebm/core/base_integrator.py
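The contract is that `integrate` is repeated application of `step` on a generic state dict. A minimal sketch on plain Python floats (a hypothetical explicit-Euler integrator for the drift \(f(x) = -x\), not a torchebm class):

```python
def euler_step(state, step_size):
    # One explicit Euler application: x <- x + h * f(x), with f(x) = -x.
    x = state["x"]
    return {"x": x + step_size * (-x)}

def integrate(state, step_size, n_steps):
    # `integrate` is just `step` applied n_steps times.
    for _ in range(n_steps):
        state = euler_step(state, step_size)
    return state

out = integrate({"x": 1.0}, step_size=0.01, n_steps=100)
# 100 Euler steps of dx/dt = -x from x=1 approximate exp(-1) ≈ 0.37.
print(out["x"])
```

An HMC-style integrator would carry both `"x"` and `"p"` in the same dict, which is why the interface stays keyed rather than positional.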
BaseInterpolant ¶
Bases: ABC
Abstract base class for stochastic interpolants.
An interpolant defines a conditional probability path between a source distribution (typically Gaussian noise) and a target distribution (data).
The interpolation is parameterized as
\[ x_t = \alpha(t)\, x_1 + \sigma(t)\, x_0, \]
where \(x_0 \sim \mathcal{N}(0, I)\) and \(x_1 \sim p_{\text{data}}\).
Subclasses must implement compute_alpha_t and compute_sigma_t.
Source code in torchebm/core/base_interpolant.py
compute_alpha_t(t) abstractmethod ¶
Compute the data coefficient \(\alpha(t)\) and its time derivative.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor of shape (batch_size, ...). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (\(\alpha(t)\), \(\dot{\alpha}(t)\)). |
Source code in torchebm/core/base_interpolant.py
compute_d_alpha_alpha_ratio_t(t) ¶
Compute the ratio \(\dot{\alpha}(t) / \alpha(t)\) for numerical stability.
This method can be overridden for better numerical precision.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tensor | The ratio tensor. |
Source code in torchebm/core/base_interpolant.py
compute_diffusion(x, t, form='SBDM', norm=1.0) ¶
Compute diffusion coefficient for SDE sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
form | str | Diffusion form. One of 'constant' (constant diffusion), 'SBDM' (score-based diffusion matching), 'sigma' (proportional to the noise schedule), 'linear' (linear decay), 'decreasing' (faster decay towards t=1), or 'increasing-decreasing' (peak at the midpoint). | 'SBDM'
norm | float | Scaling factor for diffusion. | 1.0 |
Returns:
| Type | Description |
|---|---|
Tensor | Diffusion coefficient tensor. |
Source code in torchebm/core/base_interpolant.py
compute_drift(x, t) ¶
Compute drift coefficients for score-based parameterization.
For the probability flow ODE in score parameterization: dx = [-drift_mean + drift_var * score] dt
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (drift_mean, drift_var) for score parameterization. |
Source code in torchebm/core/base_interpolant.py
compute_sigma_t(t) abstractmethod ¶
Compute the noise coefficient \(\sigma(t)\) and its time derivative.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor of shape (batch_size, ...). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (\(\sigma(t)\), \(\dot{\sigma}(t)\)). |
Source code in torchebm/core/base_interpolant.py
interpolate(x0, x1, t) ¶
Compute the interpolated sample \(x_t\) and conditional velocity \(u_t\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x0 | Tensor | Noise samples of shape (batch_size, ...). | required |
x1 | Tensor | Data samples of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (x_t, u_t) where: - x_t = α(t) x₁ + σ(t) x₀ - u_t = α̇(t) x₁ + σ̇(t) x₀ |
Source code in torchebm/core/base_interpolant.py
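A concrete sketch of `interpolate` for a hypothetical linear schedule with \(\alpha(t) = t\) and \(\sigma(t) = 1 - t\) (so \(\dot{\alpha} = 1\), \(\dot{\sigma} = -1\)), on scalar samples; actual schedules live in `BaseInterpolant` subclasses:

```python
def interpolate(x0, x1, t):
    # Linear interpolant: alpha(t) = t, sigma(t) = 1 - t.
    alpha, d_alpha = t, 1.0
    sigma, d_sigma = 1.0 - t, -1.0
    x_t = alpha * x1 + sigma * x0       # interpolated sample
    u_t = d_alpha * x1 + d_sigma * x0   # conditional velocity
    return x_t, u_t

x_t, u_t = interpolate(x0=0.0, x1=2.0, t=0.5)
print(x_t, u_t)  # → 1.0 2.0
```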
score_to_velocity(score, x, t) ¶
Convert score prediction to velocity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
score | Tensor | Predicted score of shape (batch_size, ...). | required |
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Velocity tensor of shape (batch_size, ...). |
Source code in torchebm/core/base_interpolant.py
velocity_to_noise(velocity, x, t) ¶
Convert velocity prediction to noise prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
velocity | Tensor | Predicted velocity of shape (batch_size, ...). | required |
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Noise tensor of shape (batch_size, ...). |
Source code in torchebm/core/base_interpolant.py
velocity_to_score(velocity, x, t) ¶
Convert velocity prediction to score.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
velocity | Tensor | Predicted velocity of shape (batch_size, ...). | required |
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Score tensor of shape (batch_size, ...). |
Source code in torchebm/core/base_interpolant.py
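One standard derivation of this conversion, sketched under the parameterization \(x_t = \alpha(t) x_1 + \sigma(t) x_0\) with Gaussian \(x_0\) (the library's exact sign conventions may differ). Solving the linear system formed by the state and the velocity \(v = \dot{\alpha}\, x_1 + \dot{\sigma}\, x_0\) for the noise gives

\[
x_0 = \frac{\dot{\alpha}\, x - \alpha\, v}{\dot{\alpha}\, \sigma - \alpha\, \dot{\sigma}},
\]

and since the conditional score of the Gaussian perturbation is \(\nabla_{x}\log p(x \mid x_1) = -x_0 / \sigma\),

\[
s(x, t) = \frac{\alpha\, v - \dot{\alpha}\, x}{\sigma\,\left(\dot{\alpha}\, \sigma - \alpha\, \dot{\sigma}\right)}.
\]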
BaseLoss ¶
Bases: DeviceMixin, Module, ABC
Abstract base class for loss functions used in energy-based models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dtype | dtype | Data type for computations. | float32 |
device | Optional[Union[str, device]] | Device for computations. | None |
use_mixed_precision | bool | Whether to use mixed precision training. | False |
clip_value | Optional[float] | Optional value to clamp the loss. | None |
Source code in torchebm/core/base_loss.py
__call__(x, *args, **kwargs) ¶
Calls the forward method of the loss function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The computed loss value. |
Source code in torchebm/core/base_loss.py
__init__(dtype=torch.float32, device=None, use_mixed_precision=False, clip_value=None, *args, **kwargs) ¶
Initialize the base loss class.
Source code in torchebm/core/base_loss.py
__repr__() ¶
__str__() ¶
forward(x, *args, **kwargs) abstractmethod ¶
Computes the loss value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor from the target distribution. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The computed scalar loss value. |
Source code in torchebm/core/base_loss.py
BaseModel ¶
Bases: DeviceMixin, Module, ABC
Abstract base class for energy-based models (EBMs).
This class provides a unified interface for defining EBMs, which represent the unnormalized negative log-likelihood of a probability distribution. It supports both analytical models and trainable neural networks.
Subclasses must implement the forward(x) method and can optionally override the gradient(x) method for analytical gradients.
Source code in torchebm/core/base_model.py
__call__(x, *args, **kwargs) ¶
Alias for the forward method for standard PyTorch module usage.
Source code in torchebm/core/base_model.py
__init__(dtype=torch.float32, use_mixed_precision=False, *args, **kwargs) ¶
Initializes the BaseModel base class.
Source code in torchebm/core/base_model.py
forward(x) abstractmethod ¶
Computes the scalar energy value for each input sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input tensor of shape (batch_size, *input_dims). | required |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: Tensor of scalar energy values with shape (batch_size,). |
Source code in torchebm/core/base_model.py
gradient(x) ¶
Computes the gradient of the energy function with respect to the input, \(\nabla_x E(x)\).
This default implementation uses torch.autograd. Subclasses can override it for analytical gradients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input tensor of shape (batch_size, *input_dims). | required |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: Gradient tensor of the same shape as |
Source code in torchebm/core/base_model.py
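The autograd pattern behind the default implementation can be sketched as follows, using a toy quadratic energy \(E(x) = \tfrac{1}{2}\lVert x \rVert^2\) whose analytical gradient is \(x\) (a sketch of the technique, not torchebm's exact code):

```python
import torch

def energy(x):
    # Toy energy with shape (batch_size,); stands in for forward(x).
    return 0.5 * (x ** 2).sum(dim=-1)

def gradient(x):
    # Detach from any existing graph, then track gradients w.r.t. the input.
    x = x.detach().requires_grad_(True)
    e = energy(x)
    # Summing over the batch yields per-sample input gradients.
    (grad,) = torch.autograd.grad(e.sum(), x)
    return grad

x = torch.tensor([[1.0, -2.0]])
g = gradient(x)
print(g)  # matches the analytical gradient x
```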
BaseRungeKuttaIntegrator ¶
Bases: BaseIntegrator
Abstract base class for explicit Runge-Kutta ODE integrators.
Subclasses define a Butcher tableau via the abstract properties tableau_a, tableau_b, and tableau_c and automatically inherit generic stepping and integration logic.
For an \(s\)-stage explicit RK method the update reads
\[ k_i = f\Big(t + c_i h,\; x + h \sum_{j<i} a_{ij} k_j\Big), \qquad x' = x + h \sum_{i=1}^{s} b_i k_i. \]
Adaptive step-size control is available automatically for subclasses that define error_weights and order. When adaptive=True is passed to integrate (or left as None for auto-detection), the integrator uses an embedded error pair to control the step size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
device | Optional[device] | Device for computations. | None |
dtype | Optional[dtype] | Data type for computations. | None |
atol | float | Absolute tolerance for adaptive stepping. | 1e-06 |
rtol | float | Relative tolerance for adaptive stepping. | 0.001 |
max_steps | int | Maximum number of steps (accepted + rejected) before adaptive integration raises an error. | 10000
safety | float | Safety factor for step-size adjustment (< 1). | 0.9 |
min_factor | float | Minimum step-size shrink factor. | 0.2 |
max_factor | float | Maximum step-size growth factor. | 10.0 |
max_step_size | float | Maximum absolute step size allowed during adaptive integration. | float('inf')
norm | Optional[Callable[[Tensor], Tensor]] | Callable computing the error norm used for adaptive step-size control. | None
Source code in torchebm/core/base_integrator.py
error_weights property ¶
Error estimation weights \(e_i = b_i - \hat{b}_i\).
Return None (the default) for methods without an embedded pair. For FSAL methods the tuple has n_stages + 1 entries; for non-FSAL methods it has n_stages entries.
fsal property ¶
Whether the method has the First Same As Last property.
When True the integrator evaluates one extra stage at the accepted solution and reuses it as the first stage of the next step, saving one drift evaluation per accepted step.
n_stages property ¶
Number of stages s in the method.
order property ¶
Order p of the higher-order solution.
Used in the step-size control exponent \(-1/p\). Return None (the default) for methods without adaptive support.
tableau_a abstractmethod property ¶
Lower-triangular RK matrix \(a_{ij}\).
tableau_a[i] contains coefficients \(a_{i0}, \ldots, a_{i,i-1}\). The first row is the empty tuple ().
tableau_b abstractmethod property ¶
Weights \(b_i\) used to combine stages into the final update.
tableau_c abstractmethod property ¶
Nodes \(c_i\) — time-fraction offsets for each stage evaluation.
integrate(state, step_size, n_steps, *, drift=None, t=None, adaptive=None, inference_mode=False) ¶
Integrate the state over a time interval (ODE).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping with key 'x' holding the current position. | required
step_size | Tensor | Uniform step size (fixed mode) or initial step size (adaptive mode). | required
n_steps | int | Number of integration steps (fixed mode) or, together with step_size, the integration horizon (adaptive mode). | required
drift | Optional[Callable[[Tensor, Tensor], Tensor]] | Explicit drift callable taking (x, t). | None
t | Optional[Tensor] | 1-D time grid. Built from step_size and n_steps when omitted. | None
adaptive | Optional[bool] | Whether to use adaptive step-size control; None auto-detects support for an embedded error pair. | None
inference_mode | bool | When True, integration runs without gradient tracking. | False
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict |
Source code in torchebm/core/base_integrator.py
454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 | |
step(state, step_size, *, drift=None, t=None) ¶
Advance the state by one deterministic RK step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping containing the current position under key 'x'. | required
step_size | Tensor | Step size for the integration. | required
drift | Optional[Callable[[Tensor, Tensor], Tensor]] | Explicit drift callable taking (x, t). | None
t | Optional[Tensor] | Current time tensor (batch,). | None
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict |
Source code in torchebm/core/base_integrator.py
BaseSDERungeKuttaIntegrator ¶
Bases: BaseRungeKuttaIntegrator
Runge-Kutta integrator with additive SDE noise.
Extends BaseRungeKuttaIntegrator to solve Ito SDEs of the form
\[ dx = f(x, t)\,dt + \sqrt{2 D(x, t)}\,dW_t. \]
The stochastic term is applied as an Euler-order additive correction after the deterministic RK update:
\[ x_{n+1} = x_{\text{RK}} + \sqrt{2 D\, \Delta t}\; \xi, \qquad \xi \sim \mathcal{N}(0, I). \]
Because the noise is added independently of the RK stages, the strong convergence order is \(0.5\) (Euler--Maruyama level) regardless of the underlying RK scheme order. The higher-order RK tableau improves only the deterministic component.
When diffusion is omitted the integrator reduces to its parent ODE behaviour.
Source code in torchebm/core/base_integrator.py
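The two-phase update can be sketched on plain floats (a single Euler step stands in for the RK update here, and the \(\sqrt{2 D \Delta t}\) prefactor assumes the physics convention implied by `noise_scale` squared being used as \(D\); the library's exact prefactor may differ):

```python
import math
import random

def sde_step(x, h, drift, D, rng):
    # Deterministic part (stand-in for the full RK tableau update) ...
    x_det = x + h * drift(x)
    # ... followed by the additive Euler-Maruyama noise correction.
    xi = rng.gauss(0.0, 1.0)
    return x_det + math.sqrt(2.0 * D * h) * xi

rng = random.Random(0)
x = sde_step(1.0, h=0.01, drift=lambda x: -x, D=0.0, rng=rng)
print(x)  # with D = 0 the step reduces to the deterministic update: 0.99
```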
integrate(state, step_size, n_steps, *, drift=None, diffusion=None, noise_scale=None, t=None, adaptive=None, inference_mode=False) ¶
Integrate the state over a time interval (ODE or SDE).
When diffusion or noise_scale is provided the integration uses fixed-step SDE mode. Adaptive step-size control is available only for the ODE case (no diffusion).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping with key 'x' holding the current position. | required
step_size | Tensor | Step size (fixed) or initial step size (adaptive). | required
n_steps | int | Number of integration steps. | required
drift | Optional[Callable[[Tensor, Tensor], Tensor]] | Explicit drift callable taking (x, t). | None
diffusion | Optional[Callable[[Tensor, Tensor], Tensor]] | Time-dependent diffusion callable \(D(x, t)\). | None
noise_scale | Optional[Tensor] | Scalar whose square is used as \(D\) when diffusion is omitted. | None
t | Optional[Tensor] | 1-D time grid. | None
adaptive | Optional[bool] | Whether to use adaptive step-size control (ODE case only); None auto-detects support for an embedded error pair. | None
inference_mode | bool | When True, integration runs without gradient tracking. | False
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict |
Source code in torchebm/core/base_integrator.py
638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 | |
step(state, step_size, *, drift=None, diffusion=None, noise=None, noise_scale=None, t=None) ¶
Advance the state by one RK step with optional SDE noise.
The deterministic update uses the Butcher tableau defined by the subclass. When a diffusion coefficient is provided, additive Wiener noise is appended at Euler--Maruyama order (strong order \(0.5\)).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state | Dict[str, Tensor] | Mapping containing the current position under key 'x'. | required
step_size | Tensor | Step size for the integration. | required
drift | Optional[Callable[[Tensor, Tensor], Tensor]] | Explicit drift callable taking (x, t). | None
diffusion | Optional[Tensor] | Diffusion coefficient \(D(x, t)\) tensor. | None
noise | Optional[Tensor] | Pre-sampled noise tensor. When omitted, standard Gaussian noise is drawn internally. | None
noise_scale | Optional[Tensor] | Scalar whose square is used as \(D\) when diffusion is omitted. | None
t | Optional[Tensor] | Current time tensor (batch,). | None
Returns:
| Type | Description |
|---|---|
Dict[str, Tensor] | Updated state dict |
Source code in torchebm/core/base_integrator.py
BaseSampler ¶
Bases: DeviceMixin, Module, ABC
Abstract base class for samplers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | The model to sample from. For MCMC samplers, this is typically a BaseModel defining the energy function. | required
dtype | dtype | The data type for computations. | float32 |
device | Optional[Union[str, device]] | The device for computations. | None |
use_mixed_precision | bool | Whether to use mixed-precision for sampling. | False |
Source code in torchebm/core/base_sampler.py
apply_mixed_precision(func) ¶
A decorator to apply the mixed precision context to a method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
func | | The function to wrap. | required
Returns:
| Type | Description |
|---|---|
| The wrapped function. |
Source code in torchebm/core/base_sampler.py
get_scheduled_value(name) ¶
Gets the current value for a scheduled parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | The name of the scheduled parameter. | required |
Returns:
| Name | Type | Description |
|---|---|---|
float | float | The current value of the parameter. |
Raises:
| Type | Description |
|---|---|
KeyError | If no scheduler is registered for the parameter. |
Source code in torchebm/core/base_sampler.py
get_schedulers() ¶
Gets all registered schedulers.
Returns:
| Type | Description |
|---|---|
Dict[str, BaseScheduler] | Dict[str, BaseScheduler]: A dictionary mapping parameter names to their schedulers. |
register_scheduler(name, scheduler) ¶
Registers a parameter scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | The name of the parameter to schedule. | required |
scheduler | BaseScheduler | The scheduler instance. | required |
Source code in torchebm/core/base_sampler.py
reset_schedulers() ¶
Resets all registered schedulers to their initial state.
sample(x=None, dim=10, n_steps=100, n_samples=1, thin=1, return_trajectory=False, return_diagnostics=False, *args, **kwargs) abstractmethod ¶
Runs the sampling process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Optional[Tensor] | The initial state to start sampling from. | None |
dim | int | The dimension of the state space. | 10 |
n_steps | int | The number of MCMC steps to perform. | 100 |
n_samples | int | The number of samples to generate. | 1 |
thin | int | The thinning factor for samples (currently not supported). | 1 |
return_trajectory | bool | Whether to return the full trajectory of the samples. | False |
return_diagnostics | bool | Whether to return diagnostics of the sampling process. | False |
Returns:
| Type | Description |
|---|---|
Union[Tensor, Tuple[Tensor, List[dict]]] | Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]: - A tensor of samples from the model. - If return_diagnostics is True, a tuple of the samples and a list of diagnostic dicts. |
Source code in torchebm/core/base_sampler.py
step_schedulers() ¶
Advances all schedulers by one step.
Returns:
| Type | Description |
|---|---|
Dict[str, float] | Dict[str, float]: A dictionary mapping parameter names to their updated values. |
Source code in torchebm/core/base_sampler.py
to(*args, **kwargs) ¶
Moves the sampler and its components to the specified device and/or dtype.
Source code in torchebm/core/base_sampler.py
BaseScheduler ¶
Bases: ABC
Abstract base class for parameter schedulers.
This class provides the foundation for all parameter scheduling strategies in TorchEBM. Schedulers are used to dynamically adjust parameters such as step sizes, noise scales, learning rates, and other hyperparameters during training or sampling processes.
The scheduler maintains an internal step counter and computes parameter values based on the current step. Subclasses must implement the _compute_value method to define the specific scheduling strategy.
Mathematical Foundation
A scheduler defines a function \(f: \mathbb{N} \to \mathbb{R}\) that maps step numbers to parameter values:
\[ v(t) = f(t), \]
where \(t\) is the current step count and \(v(t)\) is the parameter value at step \(t\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Initial parameter value at step 0. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
start_value | float | The initial parameter value. |
current_value | float | The current parameter value. |
step_count | int | Number of steps taken since initialization or last reset. |
Source code in torchebm/core/base_scheduler.py
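The documented contract can be sketched with a hypothetical subclass (not a torchebm class) implementing `_compute_value` for exponential decay, alongside the `step`/`get_value`/`reset`/`state_dict` interface described here:

```python
class ExponentialDecayScheduler:
    """Illustrative scheduler: v(t) = start_value * decay ** t."""

    def __init__(self, start_value, decay=0.9):
        self.start_value = start_value
        self.decay = decay
        self.reset()

    def _compute_value(self, step):
        # The scheduling strategy: subclasses of BaseScheduler override this.
        return self.start_value * self.decay ** step

    def step(self):
        # Advance the counter and recompute the current value.
        self.step_count += 1
        self.current_value = self._compute_value(self.step_count)
        return self.current_value

    def get_value(self):
        # Query the current value without advancing.
        return self.current_value

    def reset(self):
        self.step_count = 0
        self.current_value = self.start_value

    def state_dict(self):
        return {"step_count": self.step_count, "current_value": self.current_value}

    def load_state_dict(self, state):
        self.step_count = state["step_count"]
        self.current_value = state["current_value"]

sched = ExponentialDecayScheduler(start_value=0.1)
sched.step()
print(sched.get_value())  # decayed once from 0.1
```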
__init__(start_value) ¶
Initialize the base scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Initial parameter value. Must be a finite number. | required |
Raises:
| Type | Description |
|---|---|
TypeError | If start_value is not a float or int. |
Source code in torchebm/core/base_scheduler.py
get_value() ¶
Get the current parameter value without advancing the scheduler.
This method returns the current parameter value without modifying the scheduler's internal state. Use this when you need to query the current value without stepping.
Returns:
| Name | Type | Description |
|---|---|---|
float | float | The current parameter value. |
Query Current Value
Source code in torchebm/core/base_scheduler.py
load_state_dict(state_dict) ¶
Load the scheduler's state from a dictionary.
This method restores the scheduler's internal state from a dictionary previously created by state_dict(). This is useful for resuming training or sampling from a checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state_dict | Dict[str, Any] | Dictionary containing the scheduler state. Should be an object returned from a call to | required |
State Restoration
Source code in torchebm/core/base_scheduler.py
reset() ¶
Reset the scheduler to its initial state.
This method resets both the step counter and current value to their initial states, effectively restarting the scheduling process.
Reset Example
Source code in torchebm/core/base_scheduler.py
state_dict() ¶
Return the state of the scheduler as a dictionary.
This method returns a dictionary containing all the scheduler's internal state, which can be used to save and restore the scheduler's state.
Returns:
| Type | Description |
|---|---|
Dict[str, Any] | Dict[str, Any]: Dictionary containing the scheduler's state. |
State Management
Source code in torchebm/core/base_scheduler.py
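As described above, `state_dict()` and `load_state_dict()` allow a scheduler to be checkpointed and resumed. The round-trip pattern looks like the sketch below, where `TinyLinear` is a hypothetical stand-in class and its state keys (`step_count`, `current_value`) are assumptions based on the attributes documented for BaseScheduler.

```python
# Checkpoint round-trip sketch. TinyLinear is a hypothetical stand-in;
# the state keys mirror the attributes documented for BaseScheduler.
class TinyLinear:
    def __init__(self, start, end, n_steps):
        self.start, self.end, self.n_steps = start, end, n_steps
        self.step_count = 0
        self.current_value = start

    def step(self):
        self.step_count += 1
        frac = min(self.step_count / self.n_steps, 1.0)
        self.current_value = self.start + (self.end - self.start) * frac
        return self.current_value

    def state_dict(self):
        # Only mutable state is saved; constructor args are re-supplied on load.
        return {"step_count": self.step_count, "current_value": self.current_value}

    def load_state_dict(self, state):
        self.step_count = state["step_count"]
        self.current_value = state["current_value"]


a = TinyLinear(1.0, 0.0, 10)
for _ in range(4):
    a.step()

ckpt = a.state_dict()          # in practice, saved with e.g. torch.save

b = TinyLinear(1.0, 0.0, 10)   # fresh instance at resume time
b.load_state_dict(ckpt)
# b now continues exactly where a left off
```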
step() ¶
Advance the scheduler by one step and return the new parameter value.
This method increments the internal step counter and computes the new parameter value using the scheduler's strategy. The computed value becomes the new current value.
Returns:
| Name | Type | Description |
|---|---|---|
float | float | The new parameter value after stepping. |
Basic Usage
Source code in torchebm/core/base_scheduler.py
BaseScoreMatching ¶
Bases: BaseLoss
Abstract base class for Score Matching based loss functions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | BaseModel | The energy-based model to be trained. | required |
noise_scale | float | The scale of noise for perturbation in denoising variants. | 0.01 |
regularization_strength | float | The coefficient for regularization terms. | 0.0 |
use_autograd | bool | Whether to use | True |
hutchinson_samples | int | The number of random samples for Hutchinson's trick. | 1 |
custom_regularization | Optional[Callable] | An optional function for custom regularization. | None |
use_mixed_precision | bool | Whether to use mixed precision training. | False |
clip_value | Optional[float] | Optional value to clamp the loss. | None |
Source code in torchebm/core/base_loss.py
__call__(x, *args, **kwargs) ¶
Calls the forward method of the loss function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The computed loss. |
Source code in torchebm/core/base_loss.py
__repr__() ¶
__str__() ¶
add_regularization(loss, x, custom_reg_fn=None, reg_strength=None) ¶
Adds regularization terms to the loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
loss | Tensor | The current loss value. | required |
x | Tensor | The input tensor. | required |
custom_reg_fn | Optional[Callable] | An optional custom regularization function. | None |
reg_strength | Optional[float] | An optional regularization strength. | None |
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The loss with the regularization term added. |
Source code in torchebm/core/base_loss.py
compute_loss(x, *args, **kwargs) abstractmethod ¶
Computes the specific score matching loss variant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The specific score matching loss. |
Source code in torchebm/core/base_loss.py
compute_score(x, noise=None) ¶
Computes the score function \(-\nabla_x E(x)\), the negative gradient of the energy with respect to the input.
Source code in torchebm/core/base_loss.py
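The score definition \(s(x) = -\nabla_x E(x)\) can be sanity-checked numerically. The snippet below does this for the harmonic energy \(E(x) = \tfrac{1}{2} k \sum_i x_i^2\) (so \(s(x) = -kx\)) with a standalone numpy finite-difference estimate; torchebm computes the gradient via autograd instead, so this is purely illustrative.

```python
# Numerical sanity check of the score definition s(x) = -grad E(x),
# using the harmonic energy E(x) = 0.5 * k * sum(x_i^2), so s(x) = -k x.
# Standalone numpy sketch; torchebm uses autograd, not finite differences.
import numpy as np

k = 2.0
energy = lambda x: 0.5 * k * np.sum(x**2)

def score_fd(x, eps=1e-5):
    """Central finite-difference estimate of -grad E at x."""
    s = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        s[i] = -(energy(x + e) - energy(x - e)) / (2 * eps)
    return s

x = np.array([0.5, -1.0, 2.0])
print(score_fd(x))  # close to -k * x = [-1., 2., -4.]
```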
forward(x, *args, **kwargs) abstractmethod ¶
Computes the score matching loss given input data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor. | required |
*args | | Additional positional arguments. | ()
**kwargs | | Additional keyword arguments. | {}
Returns:
| Type | Description |
|---|---|
Tensor | torch.Tensor: The computed score matching loss. |
Source code in torchebm/core/base_loss.py
perturb_data(x) ¶
Perturbs the input data with Gaussian noise for denoising variants.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input data tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple[torch.Tensor, torch.Tensor]: A tuple containing the perturbed data and the noise that was added. |
Source code in torchebm/core/base_loss.py
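The perturbation used by denoising variants is \(\tilde{x} = x + \sigma \varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\); returning the noise alongside the perturbed data lets the loss later compare the model's score against it. The function below is a standalone numpy sketch of that idea, not the torchebm method itself.

```python
# Sketch of the Gaussian perturbation used by denoising variants:
# x_tilde = x + noise_scale * eps, with eps ~ N(0, I). The noise is
# returned as well so the loss can use it as a regression target.
# Standalone numpy illustration, not the torchebm implementation.
import numpy as np

def perturb_data(x, noise_scale=0.01, rng=None):
    rng = rng or np.random.default_rng(0)
    noise = rng.standard_normal(x.shape)
    return x + noise_scale * noise, noise

x = np.zeros((4, 2))
x_noisy, noise = perturb_data(x, noise_scale=0.1)
# x_noisy - x recovers noise_scale * noise exactly
```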
ConstantScheduler ¶
Bases: BaseScheduler
Scheduler that maintains a constant parameter value.
This scheduler returns the same value at every step, effectively providing no scheduling. It's useful as a baseline or when you want to disable scheduling for certain parameters while keeping the scheduler interface.
Mathematical Formula
\[ v(t) = v_0 \quad \text{for all } t \geq 0 \]

where \(v_0\) is the start_value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | The constant value to maintain. | required |
Basic Usage
Using with Samplers
Source code in torchebm/core/base_scheduler.py
CosineScheduler ¶
Bases: BaseScheduler
Scheduler with cosine annealing.
This scheduler implements cosine annealing, which provides a smooth transition from the start value to the end value following a cosine curve. Cosine annealing is popular in deep learning because the parameter changes gently at the start and end of the schedule and fastest near the midpoint, which can help with convergence.
Mathematical Formula

\[ v(t) = v_{end} + \frac{1}{2}\left(v_0 - v_{end}\right)\left(1 + \cos\left(\frac{\pi \min(t, T)}{T}\right)\right) \]

where:
- \(v_0\) is the start_value
- \(v_{end}\) is the end_value
- \(T\) is n_steps
- \(t\) is the current step count
Cosine Curve Properties
The cosine function creates a smooth S-shaped curve: the decay is gentle near the start, fastest at the midpoint, and gentle again as the value approaches the end value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Starting parameter value. | required |
end_value | float | Target parameter value. | required |
n_steps | int | Number of steps to reach the final value. | required |
Raises:
| Type | Description |
|---|---|
ValueError | If n_steps is not positive. |
Step Size Annealing
Learning Rate Scheduling
Noise Scale Annealing
Source code in torchebm/core/base_scheduler.py
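The example code for the scenarios listed above (step size annealing, learning rate scheduling, noise scale annealing) is not shown here. The values a cosine schedule produces can be computed directly from the formula; the pure-math sketch below evaluates it for a step-size anneal and is not the torchebm class itself.

```python
# Evaluating the cosine-annealing formula from the section above:
# v(t) = v_end + 0.5 * (v0 - v_end) * (1 + cos(pi * t / T)), clamped at t = T.
# Pure-math sketch of the documented schedule, not the torchebm class.
import math

def cosine_value(t, start, end, n_steps):
    t = min(t, n_steps)
    return end + 0.5 * (start - end) * (1 + math.cos(math.pi * t / n_steps))

# Anneal a sampler step size from 1e-2 down to 1e-4 over 100 steps.
values = [cosine_value(t, 1e-2, 1e-4, 100) for t in range(101)]
# values[0] is the start value, values[100] the end value,
# and values[50] sits exactly at their midpoint.
```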
__init__(start_value, end_value, n_steps) ¶
Initialize the cosine scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Starting parameter value. | required |
end_value | float | Target parameter value. | required |
n_steps | int | Number of steps to reach the final value. | required |
Raises:
| Type | Description |
|---|---|
ValueError | If n_steps is not positive. |
Source code in torchebm/core/base_scheduler.py
DeviceMixin ¶
A mixin for consistent device and dtype management across all modules.
This should be inherited by all classes that are sensitive to device or dtype.
Source code in torchebm/core/device_mixin.py
autocast_context() ¶
Returns a torch.cuda.amp.autocast context manager if mixed precision is enabled, otherwise a nullcontext.
Source code in torchebm/core/device_mixin.py
safe_to(obj, device=None, dtype=None) staticmethod ¶
Safely moves an object to a device and/or dtype, if it supports the .to() method.
Source code in torchebm/core/device_mixin.py
setup_mixed_precision(use_mixed_precision) ¶
Configures mixed precision settings.
Source code in torchebm/core/device_mixin.py
to(*args, **kwargs) ¶
Override to() to update internal device tracking.
Source code in torchebm/core/device_mixin.py
DoubleWellModel ¶
Bases: BaseModel
Energy-based model for a double-well potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
barrier_height | float | The height of the energy barrier between the wells. | 2.0 |
b | float | The position of the wells (default is 1.0, creating wells at ±1). | 1.0 |
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the double well energy: \(h \sum_{i=1}^{n} (x_i^2 - b^2)^2\).
Source code in torchebm/core/base_model.py
ExponentialDecayScheduler ¶
Bases: BaseScheduler
Scheduler with exponential decay.
This scheduler implements exponential decay of the parameter value according to: \(v(t) = \max(v_{min}, v_0 \times \gamma^t)\)
Exponential decay is commonly used for step sizes in optimization and sampling algorithms, as it provides rapid initial decay that slows down over time, allowing for both exploration and convergence.
Mathematical Formula

\[ v(t) = \max\left(v_{min},\; v_0 \cdot \gamma^{t}\right) \]

where:
- \(v_0\) is the start_value
- \(\gamma\) is the decay_rate \((0 < \gamma \leq 1)\)
- \(t\) is the step count
- \(v_{min}\) is the min_value (lower bound)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Initial parameter value. | required |
decay_rate | float | Decay factor applied at each step. Must be in (0, 1]. | required |
min_value | float | Minimum value to clamp the result. Defaults to 0.0. | 0.0 |
Raises:
| Type | Description |
|---|---|
ValueError | If decay_rate is not in (0, 1] or min_value is negative. |
Basic Exponential Decay
Training Loop Integration
Decay Rate Selection
- Aggressive decay: Use smaller decay_rate (e.g., 0.5)
- Gentle decay: Use larger decay_rate (e.g., 0.99)
Source code in torchebm/core/base_scheduler.py
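The two decay-rate regimes listed above can be compared by evaluating the formula directly. The snippet below is a pure-math sketch of the documented schedule \(v(t) = \max(v_{min}, v_0 \cdot \gamma^t)\), not the torchebm class.

```python
# Comparing the decay-rate regimes listed above, using the documented
# formula v(t) = max(v_min, v0 * gamma^t). Pure-math sketch only.
def exp_value(t, start, decay_rate, min_value=0.0):
    return max(min_value, start * decay_rate**t)

aggressive = [exp_value(t, 0.1, 0.5) for t in range(6)]   # halves each step
gentle = [exp_value(t, 0.1, 0.99) for t in range(6)]      # barely moves early on
# With a min_value, the schedule bottoms out instead of decaying to zero.
floored = exp_value(100, 0.1, 0.5, min_value=0.01)
```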
__init__(start_value, decay_rate, min_value=0.0) ¶
Initialize the exponential decay scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Initial parameter value. | required |
decay_rate | float | Decay factor applied at each step. Must be in (0, 1]. | required |
min_value | float | Minimum value to clamp the result. Defaults to 0.0. | 0.0 |
Raises:
| Type | Description |
|---|---|
ValueError | If decay_rate is not in (0, 1] or min_value is negative. |
Source code in torchebm/core/base_scheduler.py
GaussianModel ¶
Bases: BaseModel
Energy-based model for a Gaussian distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mean | Tensor | The mean vector (μ) of the Gaussian distribution. | required |
cov | Tensor | The covariance matrix (Σ) of the Gaussian distribution. | required |
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the Gaussian energy: \(E(x) = \frac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu)\).
Source code in torchebm/core/base_model.py
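The quadratic form above is easy to verify numerically: with identity covariance the energy reduces to half the squared distance to the mean. The snippet below is a standalone numpy illustration of the formula, not the torchebm model class.

```python
# Numeric check of the Gaussian energy E(x) = 0.5 * (x-mu)^T Sigma^-1 (x-mu).
# With identity covariance it is half the squared distance to the mean.
# Standalone numpy illustration of the formula, not the torchebm model.
import numpy as np

def gaussian_energy(x, mean, cov):
    d = x - mean
    # solve() avoids forming the explicit inverse of the covariance.
    return 0.5 * d @ np.linalg.solve(cov, d)

mean = np.array([1.0, -1.0])
cov = np.eye(2)
x = np.array([2.0, 1.0])
e = gaussian_energy(x, mean, cov)  # d = (1, 2), so E = 0.5 * (1 + 4) = 2.5
```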
HarmonicModel ¶
Bases: BaseModel
Energy-based model for a harmonic oscillator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
k | float | The spring constant. | 1.0 |
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the harmonic oscillator energy: \(\frac{1}{2} k \sum_{i=1}^{n} x_i^{2}\).
LinearScheduler ¶
Bases: BaseScheduler
Scheduler with linear interpolation between start and end values.
This scheduler linearly interpolates between a start value and an end value over a specified number of steps. After reaching the end value, it remains constant. Linear scheduling is useful when you want predictable, uniform changes in parameter values.
Mathematical Formula

\[ v(t) = v_0 + (v_{end} - v_0)\,\min\!\left(\frac{t}{T},\, 1\right) \]

where:
- \(v_0\) is the start_value
- \(v_{end}\) is the end_value
- \(T\) is n_steps
- \(t\) is the current step count
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Starting parameter value. | required |
end_value | float | Target parameter value. | required |
n_steps | int | Number of steps to reach the final value. | required |
Raises:
| Type | Description |
|---|---|
ValueError | If n_steps is not positive. |
Linear Decay
Warmup Strategy
MCMC Integration
Source code in torchebm/core/base_scheduler.py
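The example code for the scenarios listed above (linear decay, warmup, MCMC integration) is not shown here. Both decay and warmup fall out of the same interpolation formula; the sketch below evaluates it in pure Python and is not the torchebm class itself.

```python
# The linear schedule from the formula above, covering both the decay
# and warmup patterns listed: v(t) = v0 + (v_end - v0) * min(t / T, 1).
# Pure-math sketch of the documented schedule, not the torchebm class.
def linear_value(t, start, end, n_steps):
    return start + (end - start) * min(t / n_steps, 1.0)

decay = [linear_value(t, 1.0, 0.0, 4) for t in range(6)]
# -> [1.0, 0.75, 0.5, 0.25, 0.0, 0.0]  (stays constant after step T)
warmup = [linear_value(t, 0.0, 1e-3, 4) for t in range(5)]
```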
__init__(start_value, end_value, n_steps) ¶
Initialize the linear scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start_value | float | Starting parameter value. | required |
end_value | float | Target parameter value. | required |
n_steps | int | Number of steps to reach the final value. | required |
Raises:
| Type | Description |
|---|---|
ValueError | If n_steps is not positive. |
Source code in torchebm/core/base_scheduler.py
RastriginModel ¶
Bases: BaseModel
Energy-based model for the Rastrigin function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | float | The amplitude parameter \(A\) of the cosine modulation term. | 10.0
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the Rastrigin energy.
Source code in torchebm/core/base_model.py
RosenbrockModel ¶
Bases: BaseModel
Energy-based model for the Rosenbrock function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | float | The parameter \(a\) in the \((a - x_i)^2\) term, setting the location of the global minimum. | 1.0
b | float | The parameter \(b\) weighting the valley term \((x_{i+1} - x_i^2)^2\). | 100.0
Source code in torchebm/core/base_model.py
forward(x) ¶
Computes the Rosenbrock energy: \(\sum_{i=1}^{n-1} \left[ b(x_{i+1} - x_i^2)^2 + (a - x_i)^2 \right]\).
Source code in torchebm/core/base_model.py
expand_t_like_x(t, x) ¶
Expand time tensor to match spatial dimensions of x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor of shape (batch_size,). | required |
x | Tensor | Reference tensor of shape (batch_size, ...). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Time tensor expanded to shape (batch_size, 1, 1, ...). |
Source code in torchebm/core/base_interpolant.py
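The reshape described above gives the time tensor trailing singleton dimensions so it broadcasts against `x`. The snippet below is a numpy stand-in for the torch helper, written from the shape contract in the table; the implementation details are an assumption.

```python
# Sketch of the reshape described above: a time tensor of shape
# (batch_size,) gains trailing singleton dims so it broadcasts against
# x of shape (batch_size, ...). numpy stand-in for the torch helper.
import numpy as np

def expand_t_like_x(t, x):
    return t.reshape(t.shape[0], *([1] * (x.ndim - 1)))

t = np.array([0.1, 0.9])
x = np.zeros((2, 3, 32, 32))
t_exp = expand_t_like_x(t, x)
# t_exp has shape (2, 1, 1, 1), so t_exp * x broadcasts cleanly
```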
normalize_device(device) ¶
Normalizes the device identifier for consistent usage.
Converts string identifiers to torch.device objects and defaults to 'cuda' if available, otherwise 'cpu'.