Coverage for physped/utils/config_utils.py: 84%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-01 09:28 +0000

1"""Module to define utility functions for the configuration.""" 

2 

3import logging 

4import os 

5from pathlib import Path 

6from pprint import pformat 

7 

8import matplotlib.pyplot as plt 

9import numpy as np 

10from hydra import compose, initialize 

11from omegaconf import DictConfig, OmegaConf 

12 

# Project root: two levels above this file (the parent of the directory that
# contains this module).
ROOT_DIR = Path(os.path.abspath(__file__)).parent.parent

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

16 

17 

def log_configuration(config: DictConfig) -> None:
    """Log a summary of the active configuration at INFO level.

    The summary lists the environment name, the current working directory,
    the project root and the model parameters. The model parameters are
    converted to plain containers with interpolations resolved. Nested
    containers below the first level are elided by ``pformat(depth=1)``.

    Args:
        config: Configuration exposing ``params.env_name``, ``params.model``
            and ``root_dir``. These are read via attribute access and
            ``params.model`` is passed to ``OmegaConf.to_container``, so this
            is an OmegaConf ``DictConfig`` rather than a plain ``dict``.
    """
    log.info(
        (
            "\n* Environment name: %s\n\n"
            "* Working directory \n%s\n\n"
            "* Project root \n%s\n\n"
            "* Modeling parameters: \n%s"
        ),
        config.params.env_name,
        Path.cwd(),
        config.root_dir,
        pformat(
            OmegaConf.to_container(config.params.model, resolve=True), depth=1
        ),
    )

33 

34 

def apply_periodic_conditions_to_the_angle_theta(theta: float):
    """
    Wrap the angle theta into the half-open interval [-pi, pi).

    Works element-wise on NumPy arrays as well as on scalars. The input is
    never modified in place, so a caller's array is left untouched. Integer
    arrays are promoted to float instead of failing on an in-place cast.

    Args:
        theta (float | np.ndarray): The angle(s) in radians.

    Returns:
        float | np.ndarray: The wrapped angle(s), with the same shape as
        ``theta``.
    """
    # Shift by pi, reduce modulo a full turn, then shift back. A new object
    # is built here instead of using ``theta += np.pi``, which would mutate an
    # ndarray argument in place (and raise for integer arrays).
    return (theta + np.pi) % (2 * np.pi) - np.pi

47 

48 

def create_grid_name(grid_list: list) -> str:
    """Build a compact name string from a sequence of numbers.

    Each value is scaled by 10, truncated with ``int`` and prefixed with a
    dash. The pieces are then concatenated, e.g. ``[0.5, 1.2] -> "-5-12"``.

    Args:
        grid_list: Iterable of numeric values.

    Returns:
        The concatenated name; an empty string for an empty input.
    """
    return "".join(f"-{int(value * 10)}" for value in grid_list)

53 

54 

def set_plot_style(config: DictConfig, use_latex: bool = False) -> None:
    """Activate one of the project's matplotlib style sheets.

    The sheet is loaded from ``<config.root_dir>/conf/<name>.mplstyle``. The
    name is ``science`` when LaTeX is requested and ``science_no_latex``
    otherwise.

    Args:
        config: Configuration providing ``root_dir``, the project root.
        use_latex: Whether to use the LaTeX-enabled style sheet.
            Defaults to False.

    Raises:
        KeyError: If ``use_latex`` is neither True nor False (nor equal to
            one of them), because the style name is looked up in a mapping.
    """
    style_by_flag = {False: "science_no_latex", True: "science"}
    chosen = style_by_flag[use_latex]
    style_path = Path(config.root_dir) / f"conf/{chosen}.mplstyle"
    plt.style.use(style_path)

65 

66 

def _generate_linear_bins(start, stop, step):
    """Return bin edges from ``start`` to ``stop`` (inclusive) spaced by ``step``.

    ``np.arange`` excludes its end point, so ``stop`` is padded before the
    call. The pad is 0.01, capped at half a step. A fixed 0.01 pad would add
    an edge beyond ``stop`` whenever ``abs(step)`` is below 0.01, and it is
    borderline at exactly 0.01. For ``abs(step) >= 0.02`` the pad is exactly
    0.01.

    Args:
        start: First bin edge.
        stop: Last bin edge; included when it lies on the ``step`` grid.
        step: Spacing between consecutive edges.

    Returns:
        np.ndarray: The bin edges.
    """
    padding = min(0.01, abs(step) / 2)
    return np.arange(start, stop + padding, step)


def _generate_angular_bins(start, segments):
    """Return ``segments + 1`` evenly spaced edges from ``start`` to ``start + 2*pi``."""
    return np.linspace(start, start + 2 * np.pi, segments + 1)


def register_new_resolvers(replace=False):
    """Register the project's custom OmegaConf resolvers.

    Registered resolvers:
        get_root_dir: the project root directory (``ROOT_DIR``).
        parse_pi: the argument multiplied by pi.
        generate_linear_bins: inclusive linear bin edges (start, stop, step).
        generate_angular_bins: ``segments + 1`` edges covering one full turn
            from ``start`` (start, segments).
        cast_numpy_array: ``np.array`` applied to the argument.
        apply_periodic_conditions_to_the_angle_theta: periodic angle wrapping.
        inv_prop: the reciprocal ``1 / x``.
        create_grid_name: a name string built from a list of numbers.
        set_plot_style: applies a matplotlib style sheet.

    Args:
        replace: Forwarded to ``OmegaConf.register_new_resolver``; when True,
            a resolver already registered under the same name is replaced.
            Defaults to False.
    """
    OmegaConf.register_new_resolver(
        "get_root_dir", lambda: ROOT_DIR, replace=replace
    )
    OmegaConf.register_new_resolver(
        "parse_pi", lambda a: a * np.pi, replace=replace
    )
    OmegaConf.register_new_resolver(
        "generate_linear_bins", _generate_linear_bins, replace=replace
    )
    OmegaConf.register_new_resolver(
        "generate_angular_bins", _generate_angular_bins, replace=replace
    )
    OmegaConf.register_new_resolver(
        "cast_numpy_array", np.array, replace=replace
    )
    OmegaConf.register_new_resolver(
        "apply_periodic_conditions_to_the_angle_theta",
        apply_periodic_conditions_to_the_angle_theta,
        replace=replace,
    )
    OmegaConf.register_new_resolver(
        "inv_prop", lambda x: 1 / x, replace=replace
    )
    OmegaConf.register_new_resolver(
        "create_grid_name", create_grid_name, replace=replace
    )
    OmegaConf.register_new_resolver(
        "set_plot_style", set_plot_style, replace=replace
    )

101 

102 

def initialize_hydra_config(env_name: str) -> DictConfig:
    """Compose the Hydra configuration for a given environment.

    The project's custom OmegaConf resolvers are (re-)registered before
    composing. This lets any interpolation that uses them be resolved during
    or after composition. ``replace=True`` lets this function be called
    repeatedly without clashing with resolvers registered earlier.

    Args:
        env_name: The name of the environment, selected through the
            ``params=<env_name>`` override. For example: 'narrow_corridor',
            'intersecting_paths', 'asdz_pf12', 'asdz_pf34', 'utrecht_pf5'.

    Returns:
        The composed Hydra configuration, including the ``hydra`` node
        (``return_hydra_config=True``).
    """
    # Register first: Hydra may evaluate interpolations while building the
    # config, and a not-yet-registered resolver would fail at that point.
    register_new_resolvers(replace=True)
    # config_path is interpreted by Hydra relative to the calling module.
    with initialize(version_base=None, config_path="../conf"):
        config = compose(
            config_name="config",
            return_hydra_config=True,
            overrides=[
                f"params={env_name}",
            ],
        )
    return config