Coverage for physped/core/parametrize_potential.py: 42%
121 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-01 09:28 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-01 09:28 +0000
1"""Infer force fields from trajectories."""
3import copy
4import logging
6# from pathlib import Path
7from typing import List
9import numpy as np
10import pandas as pd
11from omegaconf import DictConfig
12from scipy.stats import norm
14from physped.core.digitizers import digitize_coordinates_to_lattice
15from physped.core.distribution_approximator import GaussianApproximation
16from physped.core.lattice import Lattice
17from physped.core.piecewise_potential import PiecewisePotential
19# from physped.io.readers import read_piecewise_potential_from_file
20from physped.utils.functions import (
21 compose_functions,
22 periodic_angular_conditions,
23 weighted_mean_of_two_matrices,
24)
# Module-level logger named after this module; the functions below log through it.
log = logging.getLogger(__name__)
def learn_potential_from_trajectories(
    trajectories: pd.DataFrame, config: DictConfig
) -> PiecewisePotential:
    """Learn a piecewise potential from trajectories.

    Steps:

    1. Build a ``Lattice`` from ``config.params.grid.bins`` and an empty
       ``PiecewisePotential`` that uses a ``GaussianApproximation``.
    2. Prepare the trajectories for the lattice with
       ``prepare_trajectories_for_lattice_parametrization``. That adds the
       ``fast_grid_indices`` and ``slow_grid_indices`` columns.
    3. Count the trajectory points per fast and per slow lattice cell into
       ``histogram`` and ``histogram_slow``.
    4. Fit the fast dynamics per slow lattice cell into ``parametrization``.
    5. Call ``reparametrize_to_curvature(config)`` on the result.

    Args:
        trajectories: A DataFrame of trajectories. The preparation step may
            modify it in place (e.g. it gains the grid-index columns).
        config: The configuration. ``config.params.grid.bins`` defines the
            lattice. The whole config is forwarded to the fitting and
            reparametrization steps.

    Returns:
        The learned ``PiecewisePotential``.
    """
    log.info("Start learning the piecewise potential")
    lattice = Lattice(config.params.grid.bins)
    dist_approximation = GaussianApproximation()
    piecewise_potential = PiecewisePotential(lattice, dist_approximation)

    trajectories = prepare_trajectories_for_lattice_parametrization(
        trajectories, lattice=lattice
    )

    piecewise_potential.histogram = add_trajectories_to_histogram(
        piecewise_potential.histogram, trajectories, "fast_grid_indices"
    )
    piecewise_potential.histogram_slow = add_trajectories_to_histogram(
        piecewise_potential.histogram_slow, trajectories, "slow_grid_indices"
    )

    piecewise_potential.parametrization = parameterize_trajectories(
        piecewise_potential.parametrization, trajectories, config
    )

    piecewise_potential.reparametrize_to_curvature(config)
    return piecewise_potential
def apply_periodic_angular_conditions(
    trajectories: pd.DataFrame, lattice: Lattice
) -> pd.DataFrame:
    """Wrap every angular column of the trajectories onto the theta lattice.

    Each column whose name contains ``"theta"`` is replaced by the result of
    ``periodic_angular_conditions`` applied to that column and to
    ``lattice.bins["theta"]``. The intent is that the angles end up within
    the range of the angular lattice bins.

    Args:
        trajectories: The trajectory data set. Its angular columns are
            overwritten in place.
        lattice: The lattice object providing the theta bins.

    Returns:
        The same DataFrame with the angular columns wrapped.
    """
    angular_columns = list(
        filter(lambda name: "theta" in name, trajectories.columns)
    )
    for column_name in angular_columns:
        # The theta bins are looked up per column, so a lattice without
        # theta bins is only an error when there is an angular column.
        wrapped = periodic_angular_conditions(
            trajectories[column_name], lattice.bins["theta"]
        )
        trajectories[column_name] = wrapped
    log.info(
        "Periodic angular conditions applied to columns %s", angular_columns
    )
    return trajectories
def digitize_trajectories_to_grid(
    trajectories: pd.DataFrame, lattice: Lattice
) -> pd.DataFrame:
    """Digitize the fast and slow coordinates of trajectories to a lattice.

    For every dimension ``obs`` in ``lattice.bins``, the fast column
    (``obs + "f"``) and the slow column (``obs + "s"``) are digitized with
    ``digitize_coordinates_to_lattice``. The ``k`` dimension has no
    fast/slow variant, so the ``k`` column is digitized once and shared by
    both index tuples.

    For rows whose radial index is 0, the angular index is forced to 0. The
    angle is therefore not resolved in the lowest radial bin.

    Args:
        trajectories: The trajectories to digitize. They are modified in
            place.
        lattice: The lattice whose ``bins`` define the digitization.

    Returns:
        The trajectories with two extra columns. ``fast_grid_indices`` and
        ``slow_grid_indices`` each hold per-row tuples
        ``(x, y, r, theta, k)`` of lattice indices.
    """
    indices = {}
    for obs in lattice.bins.keys():
        bins = lattice.bins[obs]
        # "k" has no fast/slow variant: digitize its column only once.
        columns = ["k"] if obs == "k" else [obs + "f", obs + "s"]
        for column in columns:
            indices[column] = digitize_coordinates_to_lattice(
                trajectories[column], bins
            )

    # Merge all angular cells of the lowest radial bin into angular index 0.
    indices["thetaf"] = np.where(indices["rf"] == 0, 0, indices["thetaf"])
    indices["thetas"] = np.where(indices["rs"] == 0, 0, indices["thetas"])

    trajectories["fast_grid_indices"] = list(
        zip(
            indices["xf"],
            indices["yf"],
            indices["rf"],
            indices["thetaf"],
            indices["k"],
        )
    )
    trajectories["slow_grid_indices"] = list(
        zip(
            indices["xs"],
            indices["ys"],
            indices["rs"],
            indices["thetas"],
            indices["k"],
        )
    )
    return trajectories
# Preprocessing pipeline built from apply_periodic_angular_conditions and
# digitize_trajectories_to_grid. It is called as
# ``prepare_trajectories_for_lattice_parametrization(trajectories, lattice=lattice)``.
# NOTE(review): this assumes compose_functions applies its arguments left to
# right (wrap angles first, then digitize) and forwards the ``lattice``
# keyword to each step. Confirm in physped.utils.functions.
prepare_trajectories_for_lattice_parametrization = compose_functions(
    apply_periodic_angular_conditions, digitize_trajectories_to_grid
)
def add_trajectories_to_histogram(
    histogram: np.ndarray, trajectories: pd.DataFrame, groupbyindex: str
) -> np.ndarray:
    """Add the number of trajectory points per lattice cell to a histogram.

    Args:
        histogram: The histogram to add the counts to. It is updated in
            place.
        trajectories: The trajectories to add to the histogram.
        groupbyindex: Name of the column that holds, for every row, the
            tuple of lattice indices of its cell (e.g.
            ``"fast_grid_indices"`` or ``"slow_grid_indices"``).

    Returns:
        The updated histogram (the same object that was passed in).
    """
    # Only the number of rows per cell is needed, so count with size()
    # instead of materializing every group.
    cell_counts = trajectories.groupby(groupbyindex).size()
    for grid_index, count in cell_counts.items():
        histogram[grid_index] += count
    return histogram
def parameterize_trajectories(
    parametrization: np.ndarray, trajectories: pd.DataFrame, config: DictConfig
):
    """Fill the parametrization matrix with per-cell fits of the fast dynamics.

    The fast dynamics are fitted conditioned on the slow dynamics. The
    trajectories are grouped by ``slow_grid_indices`` and
    ``fit_probability_distributions`` is applied to every group. Groups
    whose fit result is NaN are discarded. Every remaining result is
    written to ``parametrization[i0, i1, i2, i3, i4, :, :]``, where
    ``i0..i4`` are the first five entries of the group's slow index tuple.

    Args:
        parametrization: The initialized, empty parametrization matrix. It
            is filled in place.
        trajectories: The trajectories to fit. They must contain a
            ``slow_grid_indices`` column.
        config: The configuration parameters, forwarded to
            ``fit_probability_distributions``.

    Returns:
        The updated parametrization matrix (the same object as the input).
    """
    per_cell_fits = trajectories.groupby("slow_grid_indices").apply(
        fit_probability_distributions, config=config
    )
    valid_fits = per_cell_fits.dropna().to_dict()
    for cell, fitted in valid_fits.items():
        x_i, y_i, r_i, theta_i, k_i = (cell[j] for j in range(5))
        parametrization[x_i, y_i, r_i, theta_i, k_i, :, :] = fitted
    log.info("Finished learning piecewise potential from trajectories.")
    return parametrization
def calculate_position_based_emperic_potential(
    histogram_slow, config: DictConfig, scale: float = 0.02
):
    """Compute an empirical potential from the slow-position occupancy.

    The slow histogram is summed (ignoring NaNs) over axes 2, 3 and 4,
    which gives one count per cell of the first two axes. Those axes are
    presumably the x and y position indices; confirm against the lattice
    layout. Cells with fewer than
    ``config.params.model.minimum_fitting_threshold`` counts are set to
    NaN. The potential is ``scale * (log(total) - log(count))``, where
    ``total`` is the NaN-ignoring sum of the whole histogram. This is a
    scaled negative log-frequency.

    Args:
        histogram_slow: Histogram with at least five axes, indexed by slow
            lattice indices.
        config: Configuration providing
            ``params.model.minimum_fitting_threshold``.
        scale: Multiplicative prefactor of the potential. Defaults to 0.02,
            the value previously hard-coded here.

    Returns:
        An array over the first two axes of ``histogram_slow``. It is NaN
        where the count is below the threshold.
    """
    position_counts = np.nansum(histogram_slow, axis=(2, 3, 4))
    # Too sparsely populated cells get no potential value.
    position_counts = np.where(
        position_counts < config.params.model.minimum_fitting_threshold,
        np.nan,
        position_counts,
    )
    total_counts = np.nansum(histogram_slow)
    return scale * (-np.log(position_counts) + np.log(total_counts))
def accumulate_grids(
    cummulative_grids: PiecewisePotential, grids_to_add: PiecewisePotential
) -> PiecewisePotential:
    """Accumulate grids by taking a weighted mean of the fit parameters.

    The goal of this function is to sum PiecewisePotential objects. For
    every fit parameter index ``p`` (one per entry of
    ``cummulative_grids.fit_param_names``), the slices
    ``fit_params[:, :, :, :, :, p]`` of both objects are combined with
    ``weighted_mean_of_two_matrices``. Each object's ``histogram`` is passed
    as its counts. Afterwards, ``histogram`` and ``histogram_slow`` of
    ``grids_to_add`` are added to those of ``cummulative_grids``.

    ``cummulative_grids`` is modified in place and also returned.

    Args:
        cummulative_grids: The cumulative grids to add to.
        grids_to_add: The grids to add to the cumulative grids.

    Returns:
        The updated cumulative grids.
    """
    # ! WARNING: This function needs to be tested. Seems to have a bug.
    # ! Perhaps this needs to be a dunder class method i.e. __add__
    # NOTE(review): elsewhere in this module, fit results are written to
    # ``parametrization`` (indexed with 7 axes) at *slow* grid indices.
    # Here a 6-axis ``fit_params`` is weighted by the *fast* ``histogram``.
    # This may be the suspected bug; confirm against PiecewisePotential.
    for p, _ in enumerate(
        cummulative_grids.fit_param_names
    ):  # Loop over all fit parameters
        # accumulate fit parameters
        # NOTE(review): the deep copies presumably protect the inputs from
        # being modified by weighted_mean_of_two_matrices. Confirm whether
        # they are needed.
        cummulative_grids.fit_params[:, :, :, :, :, p] = (
            weighted_mean_of_two_matrices(
                first_matrix=copy.deepcopy(
                    cummulative_grids.fit_params[:, :, :, :, :, p]
                ),
                counts_first_matrix=copy.deepcopy(cummulative_grids.histogram),
                second_matrix=copy.deepcopy(
                    grids_to_add.fit_params[:, :, :, :, :, p]
                ),
                counts_second_matrix=copy.deepcopy(grids_to_add.histogram),
            )
        )
    # accumulate the histograms once, after all fit parameters were merged
    # using the pre-merge counts
    cummulative_grids.histogram += grids_to_add.histogram
    cummulative_grids.histogram_slow += grids_to_add.histogram_slow
    return cummulative_grids
def extract_submatrix(
    matrix: np.ndarray, slicing_indices: List[tuple]
) -> np.ndarray:
    """Extract a submatrix from an nd-matrix using periodic boundary conditions.

    Periodicity is needed for the angular dimension. Every index range is
    wrapped modulo the length of its axis, so a range may extend below 0 or
    beyond the end of the axis.

    Args:
        matrix: The input nd-matrix to slice.
        slicing_indices: One tuple per leading dimension of ``matrix``. Each
            tuple holds the arguments for ``np.arange`` (``(start, stop)``,
            stop exclusive). Dimensions without an entry are kept whole.

    Returns:
        The submatrix. It is a copy, because fancy indexing is used.

    Raises:
        ValueError: If any range has ``start > stop``.
    """
    if any(bounds[0] > bounds[1] for bounds in slicing_indices):
        raise ValueError("Slicing indices must be ascending.")

    # Wrap each range around its own axis length. np.ix_ then builds the
    # open mesh so that every range selects along its own axis.
    wrapped_ranges = [
        np.arange(*bounds) % matrix.shape[axis]
        for axis, bounds in enumerate(slicing_indices)
    ]
    return matrix[np.ix_(*wrapped_ranges)]
def fit_probability_distributions(
    group: pd.DataFrame, config: DictConfig
) -> np.ndarray:
    """Fit a normal distribution to each fast coordinate of a group.

    A normal distribution is fitted with ``scipy.stats.norm.fit`` to each of
    the columns ``xf``, ``yf``, ``uf`` and ``vf``. That fit is a maximum
    likelihood fit, so the stored variance is the population variance
    (``ddof=0``).

    Args:
        group: The group of data points to fit the normal distributions to.
        config: The configuration parameters. It provides
            ``params.model.minimum_fitting_threshold``.

    Returns:
        A ``(4, 2)`` array. Rows follow the column order above; column 0
        holds the mean and column 1 the variance. If the group has fewer
        than ``config.params.model.minimum_fitting_threshold`` rows, the
        scalar ``np.nan`` is returned instead, so callers can drop the group
        (e.g. with ``dropna``).
    """
    if len(group) < config.params.model.minimum_fitting_threshold:
        return np.nan  # ? Can we do better if we have multiple files?
    fit_func = norm.fit  # * Other functions can be implemented here
    fast_variables = ("xf", "yf", "uf", "vf")
    fit_parameters = np.full((len(fast_variables), 2), np.nan)
    for i, variable in enumerate(fast_variables):
        # fit a normal distribution to the fast mode
        mu, std = fit_func(group[variable])
        # store the mean and the variance of the fitted distribution
        fit_parameters[i, :] = [mu, std**2]
    return fit_parameters
def get_grid_indices(
    piecewise_potential: PiecewisePotential, point: List[float]
) -> np.ndarray:
    """Return the lattice indices of a single point.

    The i-th value of ``point`` is digitized with the bins of the i-th key
    of ``piecewise_potential.lattice.bins``. ``zip`` stops at the shorter of
    the two, so surplus values or keys are ignored. The code treats index 2
    as the radial and index 3 as the angular dimension. With the lattice
    layout used in this module, the expected order is presumably
    (xs, ys, rs, thetas, k); confirm against the lattice.

    If the radial index is 0 (lowest radial bin), the angular index is set
    to 0 as well. In other words, angular velocities are not discretized
    for low radial velocity.

    Args:
        piecewise_potential: The piecewise potential providing the lattice.
        point: Slow coordinates, in the order of the lattice bins.

    Returns:
        A 1-D array of lattice indices.

    Raises:
        IndexError: If fewer than four indices are produced.
    """
    # ! Write a test for this function
    lattice_bins = piecewise_potential.lattice.bins
    indices = np.array([], dtype=int)
    for val, obs in zip(point, lattice_bins.keys()):
        indices = np.append(
            indices, digitize_coordinates_to_lattice(val, lattice_bins[obs])
        )

    # For r = 0 all theta are 0: merge the angular cells of the lowest
    # radial bin.
    if indices[2] == 0:
        indices[3] = 0
    return indices
def get_boundary_coordinates_of_selection(bins, observable, values):
    """Return the lattice edges that enclose a selection of grid indices.

    ``values`` holds a ``(first, last)`` pair of lattice indices. The lower
    bound is the edge at position ``first - 1`` and the upper bound is the
    edge at position ``last``. For ``"theta"``, both positions wrap modulo
    ``len(bins) - 1``, which makes the angular dimension periodic.

    Args:
        bins: The bin edges of the observable.
        observable: Name of the observable (``"theta"`` is periodic).
        values: The ``(first, last)`` lattice indices of the selection.

    Returns:
        A two-element list ``[lower_edge, upper_edge]``.
    """
    lower_pos, upper_pos = values[0] - 1, values[1]
    if observable == "theta":
        period = len(bins) - 1
        lower_pos, upper_pos = lower_pos % period, upper_pos % period
    return [bins[lower_pos], bins[upper_pos]]
def selection_to_bounds(bins, selection_coordinates, dimension):
    """Convert selection coordinates to the enclosing lattice-cell edges.

    The coordinates are digitized with ``digitize_coordinates_to_lattice``.
    The resulting indices are turned into edge coordinates by
    ``get_boundary_coordinates_of_selection``, which is periodic when
    ``dimension`` is ``"theta"``.

    Args:
        bins: The bin edges of the dimension.
        selection_coordinates: The ``[low, high]`` coordinates to select.
        dimension: Name of the lattice dimension.

    Returns:
        The ``[lower_edge, upper_edge]`` of the selection.
    """
    return get_boundary_coordinates_of_selection(
        bins,
        dimension,
        digitize_coordinates_to_lattice(selection_coordinates, bins),
    )
def make_grid_selection(piecewise_potential, selection):
    """Translate a selection in physical coordinates into lattice terms.

    For every ``observable -> value`` entry of ``selection``, a dictionary
    with the following keys is built:

    * ``"selection"``: the selected coordinate range ``[low, high]``.
      ``None`` (or an empty sequence) selects the full grid,
      ``[bins[0], bins[-2]]``. A single number ``v`` (including 0) selects
      ``[v, v]``.
    * ``"grid_ids"``: the lattice indices of that range, from
      ``digitize_coordinates_to_lattice``.
    * ``"bounds"``: the lattice edges enclosing the range, from
      ``get_boundary_coordinates_of_selection``.
    * ``"periodic_bounds"``: a copy of ``"bounds"``. For ``"theta"``, the
      upper bound is raised by multiples of ``2 * pi`` until it is not
      below the lower bound.

    Args:
        piecewise_potential: Object whose ``lattice.bins`` maps observable
            names to bin edges.
        selection: Mapping from observable name to ``None``, a number, or a
            ``[low, high]`` pair.

    Returns:
        A dictionary keyed by observable, holding the entries described
        above.
    """
    # TODO: Remove dependency on PiecewisePotential object
    grid_selection = {}
    for observable, value in selection.items():
        grid_bins = piecewise_potential.lattice.bins.get(observable)
        log.debug(
            "Selecting %s = %s on bins %s", observable, value, grid_bins
        )

        # Numbers are checked before truthiness, so that a selection of 0 or
        # 0.0 is a single-value selection and not the full grid.
        if isinstance(value, (int, np.integer)):
            value = [float(value), float(value)]
        elif isinstance(value, (float, np.floating)):
            value = [value, value]
        elif not value:  # None or empty: select the full grid
            value = [grid_bins[0], grid_bins[-2]]

        grid_ids = digitize_coordinates_to_lattice(value, grid_bins)
        grid_boundaries = get_boundary_coordinates_of_selection(
            grid_bins, observable, grid_ids
        )

        # Work on a copy so that unwrapping theta does not alter "bounds".
        periodic_boundaries = list(grid_boundaries)
        if observable == "theta":
            while periodic_boundaries[1] < periodic_boundaries[0]:
                periodic_boundaries[1] += 2 * np.pi

        grid_selection[observable] = {
            "selection": value,
            "grid_ids": grid_ids,
            "bounds": grid_boundaries,
            "periodic_bounds": periodic_boundaries,
        }
    return grid_selection