Coverage for physped/core/parametrize_potential.py: 42%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-01 09:28 +0000

1"""Infer force fields from trajectories.""" 

2 

3import copy 

4import logging 

5 

6# from pathlib import Path 

7from typing import List 

8 

9import numpy as np 

10import pandas as pd 

11from omegaconf import DictConfig 

12from scipy.stats import norm 

13 

14from physped.core.digitizers import digitize_coordinates_to_lattice 

15from physped.core.distribution_approximator import GaussianApproximation 

16from physped.core.lattice import Lattice 

17from physped.core.piecewise_potential import PiecewisePotential 

18 

19# from physped.io.readers import read_piecewise_potential_from_file 

20from physped.utils.functions import ( 

21 compose_functions, 

22 periodic_angular_conditions, 

23 weighted_mean_of_two_matrices, 

24) 

25 

26log = logging.getLogger(__name__) 

27 

28 

def learn_potential_from_trajectories(
    trajectories: pd.DataFrame, config: DictConfig
) -> PiecewisePotential:
    """Learn a piecewise potential from a set of trajectories.

    The trajectories are prepared for the lattice built from
    ``config.params.grid.bins``. The fast and slow histograms are then
    filled, the parametrization is fitted per slow grid cell, and the
    potential is finally reparametrized to curvature.

    Args:
        trajectories: A DataFrame of trajectories.
        config: The configuration parameters; ``config.params.grid.bins``
            defines the lattice.

    Returns:
        The piecewise potential holding the filled histograms and the
        fitted parametrization.
    """
    log.info("Start learning the piecewise potential")
    grid = Lattice(config.params.grid.bins)
    potential = PiecewisePotential(grid, GaussianApproximation())

    # Adds the "fast_grid_indices" and "slow_grid_indices" columns that
    # the histogram and fitting steps below group on.
    digitized = prepare_trajectories_for_lattice_parametrization(
        trajectories, lattice=grid
    )

    fast_counts = add_trajectories_to_histogram(
        potential.histogram, digitized, "fast_grid_indices"
    )
    potential.histogram = fast_counts
    slow_counts = add_trajectories_to_histogram(
        potential.histogram_slow, digitized, "slow_grid_indices"
    )
    potential.histogram_slow = slow_counts

    fitted = parameterize_trajectories(
        potential.parametrization, digitized, config
    )
    potential.parametrization = fitted

    potential.reparametrize_to_curvature(config)
    return potential

65 

66 

def apply_periodic_angular_conditions(
    trajectories: pd.DataFrame, lattice: Lattice
) -> pd.DataFrame:
    """Apply periodic angular conditions to every angular column.

    Every column whose name contains ``"theta"`` is passed, together with
    the angular lattice bins, through ``periodic_angular_conditions`` and
    written back into the DataFrame. The input DataFrame is modified in
    place and also returned.

    Args:
        trajectories: the trajectory data set.
        lattice: the lattice object; ``lattice.bins["theta"]`` provides the
            angular bins.

    Returns:
        The trajectories with the angular conditions applied.
    """
    angular_columns = [
        name for name in trajectories.columns if "theta" in name
    ]
    for name in angular_columns:
        wrapped = periodic_angular_conditions(
            trajectories[name], lattice.bins["theta"]
        )
        trajectories[name] = wrapped
    log.info(
        "Periodic angular conditions applied to columns %s", angular_columns
    )
    return trajectories

89 

90 

def digitize_trajectories_to_grid(
    trajectories: pd.DataFrame, lattice: Lattice
) -> pd.DataFrame:
    """Digitize trajectories to a lattice.

    For every observable in ``lattice.bins`` the fast (``<obs>f``) and slow
    (``<obs>s``) columns are digitized on the bins of that observable. The
    observable ``"k"`` has no fast/slow variants and is digitized once from
    the column ``"k"``. Two columns are added to the DataFrame, which is
    modified in place:

    * ``fast_grid_indices``: tuples ``(xf, yf, rf, thetaf, k)`` of indices.
    * ``slow_grid_indices``: tuples ``(xs, ys, rs, thetas, k)`` of indices.

    Wherever the radial index is 0 the angular index is set to 0 as well,
    so the angular direction is not resolved in the lowest radial bin.

    Args:
        trajectories: The trajectories to digitize.
        lattice: The lattice whose ``bins`` define the grid. The bins must
            contain the observables x, y, r, theta and k.

    Returns:
        The trajectories with the extra columns for the fast and the slow
        grid indices.
    """
    indices = {}
    for obs in lattice.bins:
        # "k" is shared by the fast and slow dynamics: digitize it once
        # instead of once per dynamics.
        columns = (obs,) if obs == "k" else (obs + "f", obs + "s")
        for column in columns:
            indices[column] = digitize_coordinates_to_lattice(
                trajectories[column], lattice.bins[obs]
            )

    for dynamics, target_column in (
        ("f", "fast_grid_indices"),
        ("s", "slow_grid_indices"),
    ):
        radial = indices["r" + dynamics]
        theta_key = "theta" + dynamics
        # Merge all angular cells of the lowest radial bin into cell 0.
        indices[theta_key] = np.where(radial == 0, 0, indices[theta_key])
        trajectories[target_column] = list(
            zip(
                indices["x" + dynamics],
                indices["y" + dynamics],
                radial,
                indices[theta_key],
                indices["k"],
            )
        )
    return trajectories

143 

144 

# Preprocessing pipeline used before learning the potential: combines the
# periodic angular conditions with the digitization onto the lattice.
# NOTE(review): the order in which compose_functions applies its arguments
# is defined in physped.utils.functions -- confirm there.
_LATTICE_PREPARATION_STEPS = (
    apply_periodic_angular_conditions,
    digitize_trajectories_to_grid,
)
prepare_trajectories_for_lattice_parametrization = compose_functions(
    *_LATTICE_PREPARATION_STEPS
)

148 

149 

def add_trajectories_to_histogram(
    histogram: np.ndarray, trajectories: pd.DataFrame, groupbyindex: str
) -> np.ndarray:
    """Count the trajectory rows per grid cell and add them to a histogram.

    The histogram is updated in place and also returned.

    Args:
        histogram: The histogram to add the trajectories to.
        trajectories: The trajectories to add to the histogram.
        groupbyindex: Name of the column holding the grid index of every
            row; its values are used directly to index ``histogram``.

    Returns:
        The updated histogram.
    """
    rows_per_cell = trajectories.groupby(groupbyindex)
    for cell, rows in rows_per_cell:
        n_rows = len(rows)
        histogram[cell] += n_rows
    return histogram

165 

166 

def parameterize_trajectories(
    parametrization: np.ndarray, trajectories: pd.DataFrame, config: DictConfig
) -> np.ndarray:
    """Fit the trajectories per slow grid cell and store the fit results.

    The rows are grouped on the ``slow_grid_indices`` column and every
    group is fitted with ``fit_probability_distributions``. Groups for
    which the fit returned NaN are dropped. The remaining fit matrices
    are written into ``parametrization`` at the five grid indices of
    their cell. ``parametrization`` is modified in place.

    Args:
        parametrization: The initialized, empty, parametrization matrix.
        trajectories: The trajectories to fit.
        config: The configuration parameters.

    Returns:
        The updated parametrization matrix.
    """
    fit_per_cell = trajectories.groupby("slow_grid_indices").apply(
        fit_probability_distributions, config=config
    )
    fitted_cells = fit_per_cell.dropna().to_dict()
    for cell, fit_matrix in fitted_cells.items():
        # The first five entries of the cell index address the grid cell;
        # the two trailing axes hold the fit matrix of that cell.
        x_idx, y_idx, r_idx, theta_idx, k_idx = (
            cell[axis] for axis in range(5)
        )
        parametrization[x_idx, y_idx, r_idx, theta_idx, k_idx, :, :] = (
            fit_matrix
        )
    log.info("Finished learning piecewise potential from trajectories.")
    return parametrization

192 

193 

def calculate_position_based_emperic_potential(
    histogram_slow, config: DictConfig
):
    """Compute an empirical potential from the counts per position.

    The slow histogram is summed over axes 2, 3 and 4, leaving the counts
    on the first two axes. Counts below
    ``config.params.model.minimum_fitting_threshold`` are replaced by NaN.
    The potential is ``A * (log(total counts) - log(counts))`` with the
    hard-coded scale ``A = 0.02``.

    Args:
        histogram_slow: The histogram of the slow dynamics.
        config: The configuration parameters.

    Returns:
        The potential on the first two axes of the histogram; NaN where
        the counts are below the threshold.
    """
    A = 0.02  # TODO: Move to config
    threshold = config.params.model.minimum_fitting_threshold
    counts = np.nansum(histogram_slow, axis=(2, 3, 4))
    too_few = counts < threshold
    counts = np.where(too_few, np.nan, counts)
    log_total = np.log(np.nansum(histogram_slow))
    return A * (log_total - np.log(counts))

208 

209 

def accumulate_grids(
    cummulative_grids: PiecewisePotential, grids_to_add: PiecewisePotential
) -> PiecewisePotential:
    """Merge a second set of grids into the cumulative grids.

    Every fit parameter is replaced by the result of
    ``weighted_mean_of_two_matrices`` applied to the parameter and the
    histogram of both grids. Afterwards the histograms of ``grids_to_add``
    are added to those of ``cummulative_grids``, which is modified in
    place.

    Args:
        cummulative_grids: The cumulative grids to add to.
        grids_to_add: The grids to add to the cumulative grids.

    Returns:
        The updated cumulative grids.
    """
    # ! WARNING: This function needs to be tested. Seems to have a bug.
    # ! Perhaps this needs to be a dunder class method i.e. __add__
    for param_index, _name in enumerate(cummulative_grids.fit_param_names):
        # Full slice over the five grid axes, one fit parameter on the last.
        selector = (slice(None),) * 5 + (param_index,)
        # The helper only ever sees copies of the inputs.
        merged = weighted_mean_of_two_matrices(
            first_matrix=copy.deepcopy(cummulative_grids.fit_params[selector]),
            counts_first_matrix=copy.deepcopy(cummulative_grids.histogram),
            second_matrix=copy.deepcopy(grids_to_add.fit_params[selector]),
            counts_second_matrix=copy.deepcopy(grids_to_add.histogram),
        )
        cummulative_grids.fit_params[selector] = merged

    # Accumulate the histograms once all parameters are merged, so every
    # parameter is weighted with the same counts.
    # NOTE(review): confirm these two statements belong after the loop.
    cummulative_grids.histogram += grids_to_add.histogram
    cummulative_grids.histogram_slow += grids_to_add.histogram_slow
    return cummulative_grids

246 

247 

def extract_submatrix(
    matrix: np.ndarray, slicing_indices: List[tuple]
) -> np.ndarray:
    """Extract a submatrix from a nd-matrix using periodic boundary conditions.

    Periodicity is needed for the angular dimension. Indices that fall
    outside an axis wrap around to the other side of that axis.

    Args:
        matrix: The input nd-matrix to slice.
        slicing_indices: A list of slice tuples for each dimension of the
            nd-matrix. Every tuple holds the arguments for ``np.arange``,
            i.e. ``(start, stop)`` with ``start <= stop``.

    Returns:
        The submatrix.

    Raises:
        ValueError: If the start of any slice tuple exceeds its stop.
    """
    if any(bounds[0] > bounds[1] for bounds in slicing_indices):
        raise ValueError("Slicing indices must be ascending.")

    # Wrap the requested indices of every axis onto the size of that axis.
    wrapped_indices = [
        np.arange(*bounds) % matrix.shape[axis]
        for axis, bounds in enumerate(slicing_indices)
    ]
    # np.ix_ builds the open mesh that selects the cross product of the
    # per-axis indices.
    return matrix[np.ix_(*wrapped_indices)]

273 

274 

def fit_probability_distributions(
    group: pd.DataFrame, config: DictConfig
) -> np.ndarray:
    """Fit a normal distribution to each fast variable of a group.

    Args:
        group: The group of data points to fit the normal distribution to.
            Must contain the columns ``xf``, ``yf``, ``uf`` and ``vf``.
        config: The configuration parameters;
            ``config.params.model.minimum_fitting_threshold`` is the
            minimum number of rows required for a fit.

    Returns:
        A (4, 2) matrix with one row per variable (xf, yf, uf, vf) holding
        the fitted mean and variance. ``np.nan`` is returned instead when
        the group has fewer rows than the threshold.
    """
    minimum_rows = config.params.model.minimum_fitting_threshold
    if len(group) < minimum_rows:
        return np.nan  # ? Can we do better if we have multiple files?

    fit_func = norm.fit  # * Other functions can be implemented here
    # Fit a normal distribution to each of the fast modes.
    fitted = [fit_func(group[name]) for name in ("xf", "yf", "uf", "vf")]
    # Store the mean and the variance (squared standard deviation).
    return np.array([[mean, std**2] for mean, std in fitted])

300 

301 

def get_grid_indices(
    piecewise_potential: PiecewisePotential, point: List[float]
) -> np.ndarray:
    """Return the lattice indices associated with a point.

    The values of ``point`` are paired, in order, with the observables of
    ``piecewise_potential.lattice.bins`` and digitized on the bins of that
    observable.

    If the radial velocity is in the lowest bin, the angular velocity is
    automatically also added to the lowest bin. In other words, the angular
    velocities are not discretized for low radial velocity. The code treats
    the index at position 2 as the radial one and the index at position 3
    as the angular one.
    NOTE(review): this implies the order (xs, ys, rs, thetas); confirm it
    matches the key order of the lattice bins.

    Args:
        piecewise_potential: The piecewise potential.
        point: A list with slow positions and velocities, in the order of
            the lattice bins.

    Returns:
        An integer array with the grid indices.
    """
    # ! Write a test for this function
    bins_per_observable = piecewise_potential.lattice.bins
    pieces = [np.array([], dtype=int)]
    for coordinate, observable in zip(point, bins_per_observable.keys()):
        cell = digitize_coordinates_to_lattice(
            coordinate, bins_per_observable[observable]
        )
        pieces.append(np.ravel(cell))
    indices = np.concatenate(pieces)

    # For r = 0 all theta are 0: merge grid cells for low r_s
    if indices[2] == 0:
        indices[3] = 0

    return indices

334 

335 

def get_boundary_coordinates_of_selection(bins, observable, values):
    """Return the grid bounds of a given observable and value.

    Args:
        bins: The bin edges of the observable.
        observable: Name of the observable; ``"theta"`` is treated as
            periodic.
        values: The lower and the upper grid index of the selection.

    Returns:
        A list ``[lower, upper]`` with the bin edge one position below the
        lower index and the bin edge at the upper index. For ``"theta"``
        both positions are wrapped modulo ``len(bins) - 1``.
    """
    lower_position = values[0] - 1
    upper_position = values[1]
    if observable == "theta":
        # Wrap both positions onto the periodic range of the bins.
        period = len(bins) - 1
        lower_position = lower_position % period
        upper_position = upper_position % period
    return [bins[lower_position], bins[upper_position]]

346 

347 

def selection_to_bounds(bins, selection_coordinates, dimension):
    """Return the bin boundaries belonging to the selected coordinates.

    Args:
        bins: The bin edges of the observable.
        selection_coordinates: The coordinates of the selection; they are
            digitized on ``bins``.
        dimension: Name of the observable the bins belong to.

    Returns:
        The boundary coordinates of the selection as returned by
        ``get_boundary_coordinates_of_selection``.
    """
    return get_boundary_coordinates_of_selection(
        bins,
        dimension,
        digitize_coordinates_to_lattice(selection_coordinates, bins),
    )

356 

357 

def make_grid_selection(piecewise_potential, selection):
    """Translate a selection of observable values to grid cells and bounds.

    Args:
        piecewise_potential: Object whose ``lattice.bins`` maps the
            observable names to their bin edges.
        selection: Mapping from observable name to the selected values:
            ``None`` (or an empty sequence) selects the full grid, a single
            int or float selects the cell containing that value (0
            included), and a ``[lower, upper]`` pair selects a range.

    Returns:
        A dictionary keyed by observable. Every entry is a dictionary with
        the keys ``"selection"`` (the selected value range),
        ``"grid_ids"`` (the digitized selection), ``"bounds"`` (the bin
        edges enclosing the selection) and ``"periodic_bounds"`` (a copy of
        the bounds; for ``"theta"`` the upper bound is raised by multiples
        of 2*pi until it is not below the lower bound).
    """
    # TODO: Remove dependency on PiecewisePotential object
    grid_selection = {}
    for observable, value in selection.items():
        grid_bins = piecewise_potential.lattice.bins.get(observable)
        log.debug(
            "Selecting %s on observable %s with bins %s",
            value,
            observable,
            grid_bins,
        )

        # Only a missing selection means "full grid"; a selected value of
        # 0 is a regular single-value selection.
        if value is None or np.size(value) == 0:
            value = [grid_bins[0], grid_bins[-2]]
        elif isinstance(value, int):  # select single value on grid
            value = [float(value), float(value)]
        elif isinstance(value, float):  # select single value on grid
            value = [value, value]

        grid_ids = digitize_coordinates_to_lattice(value, grid_bins)
        grid_boundaries = get_boundary_coordinates_of_selection(
            grid_bins, observable, grid_ids
        )

        # Work on a copy so that unwrapping the periodic bounds does not
        # overwrite the plain bounds stored next to them.
        periodic_boundaries = list(grid_boundaries)
        if observable == "theta":
            while periodic_boundaries[1] < periodic_boundaries[0]:
                periodic_boundaries[1] += 2 * np.pi

        grid_selection[observable] = {
            "selection": value,
            "grid_ids": grid_ids,
            "bounds": grid_boundaries,
            "periodic_bounds": periodic_boundaries,
        }
    return grid_selection