markrs/
config.rs

//! This module handles configuration I/O for the application: loading the config file,
//! validating it, and back-filling any fields it omits.
2
3use std::fmt;
4use std::str::FromStr;
5
6use log::{error, info, warn};
7use serde::{Deserialize, Serialize};
8
9use crate::CONFIG;
10use crate::io::{does_config_exist, get_config_path, write_default_config};
11
12/// Represents the global configuration for the application.
13#[derive(Debug, Deserialize, Serialize, Default)]
14pub struct Config {
15    #[serde(default)]
16    pub lexer: LexerConfig,
17    #[serde(default)]
18    pub html: HtmlConfig,
19}
20
21/// Manages all configuration for tokenization
22#[derive(Debug, Deserialize, Serialize)]
23pub struct LexerConfig {
24    #[serde(default = "default_tab_size")]
25    pub tab_size: usize,
26}
27
28impl Default for LexerConfig {
29    fn default() -> Self {
30        LexerConfig { tab_size: 4 }
31    }
32}
33
34fn default_tab_size() -> usize {
35    4
36}
37
38/// Manages all configuration for HTML generation
39#[derive(Debug, Deserialize, Serialize)]
40pub struct HtmlConfig {
41    #[serde(default = "default_css")]
42    pub css_file: String,
43    #[serde(default)]
44    pub favicon_file: String,
45    #[serde(default)]
46    pub use_prism: bool,
47    #[serde(default = "default_prism_theme")]
48    pub prism_theme: String,
49    #[serde(default = "sanitize_by_default")]
50    pub sanitize_html: bool,
51}
52
53impl Default for HtmlConfig {
54    fn default() -> Self {
55        HtmlConfig {
56            css_file: default_css(),
57            favicon_file: "".to_string(),
58            use_prism: false,
59            prism_theme: default_prism_theme(),
60            sanitize_html: sanitize_by_default(),
61        }
62    }
63}
64
/// Serde default for `html.prism_theme`: the PrismJS theme name used when the key is omitted.
fn default_prism_theme() -> String {
    String::from("vsc-dark-plus")
}
69
/// Serde default for `html.sanitize_html`: sanitization stays enabled unless the config file
/// explicitly turns it off.
fn sanitize_by_default() -> bool {
    const SANITIZE_HTML_DEFAULT: bool = true;
    SANITIZE_HTML_DEFAULT
}
74
/// Serde default for `html.css_file`: the value `"default"`, used when the key is omitted.
fn default_css() -> String {
    "default".into()
}
79
80impl Config {
81    /// Creates a new `Config` instance from the specified file path
82    ///
83    /// # Arguments
84    /// * `file_path` - The path to the configuration file. If no file path is provided as a CLI
85    ///   arg, it will check for a config file in the default config directory.
86    ///
87    /// # Returns
88    /// Returns a `Result` containing the `Config` instance if successful
89    pub fn from_file(file_path: &str) -> Result<Self, Error> {
90        // If the user provided a config file, try to load the config from it
91        if !file_path.is_empty() {
92            info!("Loading config from file: {}", file_path);
93            let contents = std::fs::read_to_string(file_path)?;
94
95            let config: Config = toml_edit::de::from_str(&contents)?;
96
97            validate_config(file_path, &contents, &config)?;
98
99            return Ok(config);
100        }
101
102        let config_path = get_config_path()?;
103
104        // If the user did not provide a config file, check if a config file exists in the config
105        // directory
106        if does_config_exist()? {
107            let contents = std::fs::read_to_string(&config_path)?;
108
109            let config: Config =
110                toml_edit::de::from_str(&contents).map_err(Error::TomlDeserialization)?;
111
112            validate_config(&config_path.to_string_lossy(), &contents, &config)?;
113
114            Ok(config)
115        } else {
116            warn!(
117                "No config file found, writing default config to: {}",
118                config_path.to_string_lossy()
119            );
120
121            let default_config = write_default_config()?;
122
123            Ok(default_config)
124        }
125    }
126}
127
128/// Validates the configuration by checking if the original config file matches the filled config
129///
130/// If the original config is missing fields, it updates the file with any missing fields
131fn validate_config(file_path: &str, contents: &str, config: &Config) -> Result<(), Error> {
132    let mut doc = toml_edit::DocumentMut::from_str(contents).map_err(Error::Toml)?;
133
134    let filled_doc = toml_edit::ser::to_document(config)?;
135
136    let mut config_needs_update = false;
137    let mut missing_fields = Vec::new();
138    for (section, values) in filled_doc.iter() {
139        let table = values.clone().into_table().unwrap_or_else(|_item| {
140            error!(
141                "Expected a table for field '{}', but found: {}",
142                section, values
143            );
144            panic!("Invalid configuration format for field '{}'", section);
145        });
146
147        for (sub_key, sub_value) in table.iter() {
148            if !doc.contains_key(section) {
149                doc[section] = filled_doc[section].clone();
150                config_needs_update = true;
151                missing_fields.push(section.to_string());
152            } else if !doc[section].is_table()
153                || !doc[section].as_table().unwrap().contains_key(sub_key)
154            {
155                doc[section][sub_key] = sub_value.clone();
156                config_needs_update = true;
157                missing_fields.push(format!("{}.{}", section, sub_key));
158            }
159        }
160    }
161
162    if config_needs_update {
163        warn!(
164            "Config is missing fields: {:?}, writing updated config to: {}",
165            missing_fields, file_path
166        );
167
168        // Formats the file with sections like `[lexer]` and `tab_size = 4`
169        // previously it would be `lexer = { tab_size = 4 }`
170        if !doc["lexer"].is_table() {
171            doc["lexer"] = doc["lexer"]
172                .clone()
173                .into_table()
174                .unwrap_or_else(|_item| {
175                    error!(
176                        "Expected 'lexer' to be a table, but found: {}",
177                        doc["lexer"]
178                    );
179                    panic!("Invalid configuration format for 'lexer'");
180                })
181                .into();
182        }
183        doc["lexer"].as_table_mut().unwrap().set_position(0);
184
185        if !doc["html"].is_table() {
186            doc["html"] = doc["html"]
187                .clone()
188                .into_table()
189                .unwrap_or_else(|_item| {
190                    error!("Expected 'html' to be a table, but found: {}", doc["html"]);
191                    panic!("Invalid configuration format for 'html'");
192                })
193                .into();
194        }
195        doc["html"].as_table_mut().unwrap().sort_values();
196
197        std::fs::write(file_path, doc.to_string())?
198    }
199
200    Ok(())
201}
202
203/// Initializes the global configuration from the specified file path
204///
205/// # Arguments
206/// * `config_path` - The path to the configuration file.
207///
208/// # Returns
209/// Returns a `Result` indicating success or failure. If successful, a global `CONFIG` has been
210/// initialized.
211pub fn init_config(config_path: &str) -> Result<(), Error> {
212    CONFIG.get_or_init(|| {
213        Config::from_file(config_path).unwrap_or_else(|err| {
214            error!("Failed to load config: {err}");
215            std::process::exit(1);
216        })
217    });
218    Ok(())
219}
220
221#[derive(Debug)]
222pub enum Error {
223    Io(std::io::Error),
224    Toml(toml_edit::TomlError),
225    TomlSerialization(toml_edit::ser::Error),
226    TomlDeserialization(toml_edit::de::Error),
227}
228
229// Display
230impl fmt::Display for Error {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        match self {
233            Error::Io(e) => write!(f, "I/O error: {e}"),
234            Error::Toml(e) => write!(f, "TOML error: {e}"),
235            Error::TomlSerialization(e) => write!(f, "TOML serialization error: {e}"),
236            Error::TomlDeserialization(e) => write!(f, "TOML deserialization error: {e}"),
237        }
238    }
239}
240
241impl From<std::io::Error> for Error {
242    fn from(err: std::io::Error) -> Self {
243        Error::Io(err)
244    }
245}
246
247impl From<toml_edit::TomlError> for Error {
248    fn from(err: toml_edit::TomlError) -> Self {
249        Error::Toml(err)
250    }
251}
252
253impl From<toml_edit::ser::Error> for Error {
254    fn from(err: toml_edit::ser::Error) -> Self {
255        Error::TomlSerialization(err)
256    }
257}
258
259impl From<toml_edit::de::Error> for Error {
260    fn from(err: toml_edit::de::Error) -> Self {
261        Error::TomlDeserialization(err)
262    }
263}
264
265impl std::error::Error for Error {
266    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
267        match self {
268            Error::Io(e) => Some(e),
269            Error::Toml(e) => Some(e),
270            Error::TomlSerialization(e) => Some(e),
271            Error::TomlDeserialization(e) => Some(e),
272        }
273    }
274}