Coverage for physped/core/piecewise_potential.py: 97%
29 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-01 09:28 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-01 09:28 +0000
1"""Module for the PiecewisePotential class."""
3import logging
5import numpy as np
6from omegaconf import DictConfig
8from physped.core.distribution_approximator import DistApproximation
9from physped.core.lattice import Lattice
11log = logging.getLogger(__name__)
class PiecewisePotential:
    """Piecewise potential defined on a lattice.

    Holds two histograms with the shape of the lattice and a
    ``parametrization`` array that stores, for every lattice site and
    every fit dimension, one value per fit parameter.
    """

    def __init__(
        self, lattice: "Lattice", dist_approximation: "DistApproximation"
    ) -> None:
        """Create the histograms and the (NaN-filled) parametrization.

        Args:
            lattice: The lattice used to discretize the slow dynamics. Its
                ``shape`` determines the shape of the histograms and the
                leading axes of the parametrization.
            dist_approximation: The distribution approximation. Its
                ``fit_dimensions`` and ``fit_parameters`` determine the
                two trailing axes of the parametrization.
        """
        self.lattice = lattice
        self.dist_approximation = dist_approximation
        # TODO: Initialize histograms with int64 instead
        self.histogram = np.zeros(self.lattice.shape, dtype=np.float64)
        self.histogram_slow = np.zeros(self.lattice.shape, dtype=np.float64)
        self.initialize_parametrization()

    def __repr__(self) -> str:
        """Return a short human-readable summary of the potential."""
        return (
            f"PiecewisePotential with dimensions {self.lattice.dimensions}"
            f", fit dimensions {self.dist_approximation.fit_dimensions}, "
            f"and parameters {self.dist_approximation.fit_parameters}"
        )

    def initialize_parametrization(self) -> None:
        """Initialize the potential parametrization.

        We initialize the parametrization with the following shape:
        (lattice_shape, len(fit_dimensions), len(fit_parameters))
        Such that the potential is parameterized in each lattice site
        for every fit dimension by the number of free fit parameters.
        Every entry starts out as NaN.
        """
        # tuple(...) so that a list-like lattice shape can be concatenated.
        shape_of_the_potential = tuple(self.lattice.shape) + (
            len(self.dist_approximation.fit_dimensions),
            len(self.dist_approximation.fit_parameters),
        )
        self.parametrization = np.full(shape_of_the_potential, np.nan)

    def reparametrize_to_curvature(self, config: "DictConfig") -> None:
        """Reparametrize the potential.

        From (mu, var) to (mu, curvature).
        Implements equations 15 and 16 from the paper.

        The second fit parameter (index 1 on the last axis) of the first
        four fit dimensions is overwritten in place, and
        ``dist_approximation.fit_parameters`` is set to
        ``["mu", "curvature"]``.

        Args:
            config: The configuration. ``config.params.model.sigma`` is
                read and squared.

        Raises:
            ValueError: If the fit parameters are not mu and sigma
        """
        # Normalise to a tuple so that a list (or any other sequence) of
        # the same names is accepted, not only a tuple.
        if tuple(self.dist_approximation.fit_parameters) != ("mu", "sigma"):
            raise ValueError("The fit parameters should be mu and sigma.")

        var = config.params.model.sigma**2
        # NOTE(review): indices 0..3 presumably correspond to the fit
        # dimensions (x, y, u, v) in that order - confirm against the
        # DistApproximation configuration.
        # Copy the slices: plain indexing returns views into the array
        # that is overwritten below, which would make the result depend on
        # the order of the assignments.
        xvar = self.parametrization[..., 0, 1].copy()
        yvar = self.parametrization[..., 1, 1].copy()
        uvar = self.parametrization[..., 2, 1].copy()
        vvar = self.parametrization[..., 3, 1].copy()

        self.parametrization[..., 0, 1] = uvar / (2 * xvar)
        self.parametrization[..., 1, 1] = vvar / (2 * yvar)
        self.parametrization[..., 2, 1] = var / (4 * uvar)
        self.parametrization[..., 3, 1] = var / (4 * vvar)
        self.dist_approximation.fit_parameters = ["mu", "curvature"]