Source code for physped.visualization.plot_potential_at_slow_index

import logging
from pathlib import Path
from typing import List

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import DictConfig

from physped.core.parametrize_potential import (
    calculate_position_based_emperic_potential,
    extract_submatrix,
)
from physped.core.piecewise_potential import PiecewisePotential
from physped.visualization.plot_utils import (
    apply_polar_plot_style,
    apply_xy_plot_style,
    highlight_velocity_selection,
    plot_polar_velocity_grid,
)

log = logging.getLogger(__name__)


def _bin_centers(edges):
    """Return the midpoints between consecutive bin edges as an array."""
    # np.asarray is a no-op for arrays and lets list-like edges work too;
    # on a plain sequence ``+`` would concatenate instead of adding.
    edges = np.asarray(edges)
    return (edges[1:] + edges[:-1]) / 2


def plot_potential_at_slow_index(
    config: DictConfig,
    slow_indices: List,
    piecewise_potential: PiecewisePotential,
):
    """Plot the empirical potential over the x-y plane at one slow index.

    The left panel shows the position-based empirical potential as a
    colour mesh with a quiver overlay of ``-curvature * (position - center)``
    per grid cell. The right panel is a polar velocity grid with the
    velocity selection highlighted. The figure is written to
    ``figures/potential_plot_at_slow_index.pdf`` under the current working
    directory; the ``figures`` folder is created if needed.

    Args:
        config: Configuration object. Reads ``params.trajectory_plot``
            (``figsize``, ``width_ratios``), ``params.force_field_plot``
            (``scale``, ``sparseness``) and ``params.grid.bins`` (``x``, ``y``).
        slow_indices: Indexable with at least five entries. Only entries 2,
            3 and 4 are used, each selecting a single bin along the third,
            fourth and fifth histogram axis. The full x and y ranges are
            always taken, so entries 0 and 1 are ignored.
        piecewise_potential: Provides ``histogram_slow`` and
            ``parametrization``, which are sliced with the same indices.

    Returns:
        None. The figure is saved to disk and left open.
    """
    params = config.params
    traj_plot_params = params.trajectory_plot
    plot_params = params.force_field_plot
    bins = params.grid.bins

    # Two-panel layout: x-y potential on the left, polar velocity grid right.
    fig = plt.figure(layout="constrained")
    fig.set_size_inches(traj_plot_params.figsize)
    spec = mpl.gridspec.GridSpec(
        ncols=2,
        nrows=1,
        width_ratios=traj_plot_params.width_ratios,
        wspace=0.1,
        hspace=0.1,
        figure=fig,
    )
    ax = fig.add_subplot(spec[0])

    cmap = "YlOrRd"
    X, Y = np.meshgrid(
        _bin_centers(bins.x), _bin_centers(bins.y), indexing="ij"
    )

    # [start, stop] per axis: every x and y bin, and exactly one bin along
    # each of the three remaining axes.
    slicing_indices = [
        [0, len(bins.x) - 1],
        [0, len(bins.y) - 1],
        [slow_indices[2], slow_indices[2] + 1],
        [slow_indices[3], slow_indices[3] + 1],
        [slow_indices[4], slow_indices[4] + 1],
    ]
    slow_subhistogram = extract_submatrix(
        piecewise_potential.histogram_slow, slicing_indices
    )
    position_based_emperic_potential = (
        calculate_position_based_emperic_potential(slow_subhistogram, config)
    )

    subparametrization = extract_submatrix(
        piecewise_potential.parametrization, slicing_indices
    )
    # The last two axes are indexed as (x=0 / y=1, center=0 / curvature=1).
    center_x = subparametrization[:, :, 0, 0, 0, 0, 0]
    center_y = subparametrization[:, :, 0, 0, 0, 1, 0]
    curvature_x = subparametrization[:, :, 0, 0, 0, 0, 1]
    curvature_y = subparametrization[:, :, 0, 0, 0, 1, 1]

    curvature_scaling = 1
    curv_x = (curvature_x * (X - center_x)) / curvature_scaling
    curv_y = (curvature_y * (Y - center_y)) / curvature_scaling

    # Hide arrows in cells whose slow-histogram count is below the threshold.
    minimum_threshold = 1
    too_few_counts = slow_subhistogram[:, :, 0, 0, 0] < minimum_threshold
    plot_curv_x = np.where(too_few_counts, np.nan, curv_x)
    plot_curv_y = np.where(too_few_counts, np.nan, curv_y)

    ax.pcolormesh(
        X, Y, position_based_emperic_potential, cmap=cmap, shading="auto"
    )

    # Draw only every ``sparseness``-th arrow in each direction.
    step = plot_params.sparseness
    ax.quiver(
        X[::step, ::step],
        Y[::step, ::step],
        -plot_curv_x[::step, ::step],
        -plot_curv_y[::step, ::step],
        scale=plot_params.scale,
        pivot="mid",
        width=0.0015,
    )
    ax = apply_xy_plot_style(ax, params)

    ax2 = fig.add_subplot(spec[1], polar=True)
    ax2 = apply_polar_plot_style(ax2, params)
    ax2 = plot_polar_velocity_grid(ax2, params.grid)
    ax2.grid(False)
    ax2 = highlight_velocity_selection(ax2, params)

    folderpath = Path.cwd() / "figures"
    folderpath.mkdir(parents=True, exist_ok=True)
    filepath = folderpath / "potential_plot_at_slow_index.pdf"
    # Save through the figure object rather than pyplot's implicit
    # "current figure" so the right figure is always written.
    fig.savefig(filepath, bbox_inches="tight")