Coverage for physped/core/piecewise_potential.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-01 09:28 +0000

1"""Module for the PiecewisePotential class.""" 

2 

3import logging 

4 

5import numpy as np 

6from omegaconf import DictConfig 

7 

8from physped.core.distribution_approximator import DistApproximation 

9from physped.core.lattice import Lattice 

10 

# Module-level logger, named after this module per the logging convention.
log = logging.getLogger(__name__)

12 

13 

class PiecewisePotential:
    """Piecewise potential defined on a discretized (lattice) state space.

    Every lattice site carries its own set of fit parameters for each fit
    dimension; they are stored in ``self.parametrization``.
    """

    def __init__(
        self, lattice: Lattice, dist_approximation: DistApproximation
    ):
        """Create an empty piecewise potential on the given lattice.

        Allocates zero-filled histograms with the lattice shape and an
        all-NaN parametrization (see ``initialize_parametrization``).

        Args:
            lattice: The lattice that discretizes the slow dynamics; its
                ``shape`` sets the shape of both histograms and the leading
                axes of the parametrization.
            dist_approximation: Description of the fit; its
                ``fit_dimensions`` and ``fit_parameters`` set the two
                trailing axes of the parametrization.
        """
        self.lattice = lattice
        self.dist_approximation = dist_approximation
        # TODO: Initialize histograms with int64 instead
        self.histogram = np.zeros(self.lattice.shape, dtype=np.float64)
        self.histogram_slow = np.zeros(self.lattice.shape, dtype=np.float64)
        self.initialize_parametrization()

    def __repr__(self):
        """Summarize the lattice dimensions and the fit setup."""
        return (
            f"PiecewisePotential with dimensions {self.lattice.dimensions}"
            f", fit dimensions {self.dist_approximation.fit_dimensions},"
            f" and parameters {self.dist_approximation.fit_parameters}"
        )

    def initialize_parametrization(self):
        """Initialize the potential parametrization with NaN everywhere.

        The parametrization has the shape
        (lattice_shape, len(fit_dimensions), len(fit_parameters)),
        such that the potential is parameterized in each lattice site
        for every fit dimension by the number of free fit parameters.
        Any previous parametrization is replaced.
        """
        shape_of_the_potential = self.lattice.shape + (
            len(self.dist_approximation.fit_dimensions),
            len(self.dist_approximation.fit_parameters),
        )
        # Every entry starts as NaN (float64); presumably the fitting code
        # overwrites sites it can fit — NOTE(review): confirm.
        self.parametrization = np.full(shape_of_the_potential, np.nan)

    def reparametrize_to_curvature(self, config: DictConfig):
        """Reparametrize the potential from (mu, var) to (mu, curvature).

        Implements equations 15 and 16 from the paper. Only the second fit
        parameter (index 1) is rewritten in place; the means (index 0) are
        left untouched. Afterwards ``fit_parameters`` is set to
        ``["mu", "curvature"]``, so calling this method a second time raises
        instead of reparametrizing twice.

        NOTE(review): indices 0-3 of the fit-dimension axis are treated as
        x, y, u, v (naming taken from the original local variables), which
        assumes ``fit_dimensions`` has exactly that order — confirm. The
        value stored under "sigma" is used as a variance here, matching the
        ``(mu, var)`` wording — confirm against the fitting code.

        Args:
            config: The configuration; ``config.params.model.sigma`` is
                squared and used in the u and v curvature formulas.

        Raises:
            ValueError: If the fit parameters are not mu and sigma (given
                either as a tuple or as a list).
        """
        # Compare as tuples: this method itself stores fit_parameters as a
        # list, and a list never compares equal to a tuple literal.
        if tuple(self.dist_approximation.fit_parameters) != ("mu", "sigma"):
            raise ValueError("The fit parameters should be mu and sigma.")

        var = config.params.model.sigma**2
        # Copy: basic indexing returns views into self.parametrization, and
        # these very slots are overwritten below.
        xvar = self.parametrization[..., 0, 1].copy()
        yvar = self.parametrization[..., 1, 1].copy()
        uvar = self.parametrization[..., 2, 1].copy()
        vvar = self.parametrization[..., 3, 1].copy()

        # NaN sites stay NaN; a zero variance yields inf/nan together with
        # a NumPy RuntimeWarning.
        self.parametrization[..., 0, 1] = uvar / (2 * xvar)
        self.parametrization[..., 1, 1] = vvar / (2 * yvar)
        self.parametrization[..., 2, 1] = var / (4 * uvar)
        self.parametrization[..., 3, 1] = var / (4 * vvar)
        self.dist_approximation.fit_parameters = ["mu", "curvature"]