lacoctelera/authentication/
token_auth.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
// Copyright 2024 Felipe Torres González
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

//! Utilities for managing access tokens of the API.

use crate::domain::{ClientId, DataDomainError, ServerError};
use argon2::{
    password_hash::SaltString,
    {Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version},
};
use chrono::{Local, TimeDelta};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use secrecy::{ExposeSecret, SecretString};
use sqlx::{Executor, MySql, MySqlPool, Transaction};
use std::{error::Error, str::FromStr};
use tracing::{debug, error, info};

/// Check if a given token matches the hash stored in the DB.
///
/// # Description
///
/// This function receives two values: the stored hash of the token in the DB, and the token used by the client in
/// a request to the API. The function hashes the given token and compares both. If both match, `Ok(())` is returned,
/// and an `Err(InvalidAccessCredentials)` otherwise.
#[tracing::instrument(name = "Validate credentials", skip(expected_token, given_token))]
pub fn verify_token(
    expected_token: SecretString,
    given_token: SecretString,
) -> Result<(), DataDomainError> {
    let expected_token_hash = PasswordHash::new(expected_token.expose_secret()).map_err(|e| {
        error!("Couldn't hash the given password: {e}");
        DataDomainError::InvalidAccessCredentials
    })?;

    match Argon2::default()
        .verify_password(given_token.expose_secret().as_bytes(), &expected_token_hash)
    {
        Ok(_) => {
            debug!("The given token matches the stored hash");
            Ok(())
        }
        Err(_) => {
            debug!("The given token does not match the stored hash");
            Err(DataDomainError::InvalidAccessCredentials)
        }
    }
}

/// Produce a fresh random token.
///
/// The token is 25 characters long, and each character is drawn independently from the alphanumeric
/// distribution using the thread-local RNG.
pub fn generate_token() -> String {
    const TOKEN_LEN: usize = 25;

    let mut generator = thread_rng();
    (0..TOKEN_LEN)
        .map(|_| char::from(generator.sample(Alphanumeric)))
        .collect()
}

/// Hash a plain token using Argon2.
pub fn generate_new_token_hash(plain_token: SecretString) -> Result<SecretString, anyhow::Error> {
    let salt = SaltString::generate(&mut rand::thread_rng());
    let token_hash = Argon2::new(
        Algorithm::Argon2id,
        Version::V0x13,
        Params::new(15000, 2, 1, None).unwrap(),
    )
    .hash_password(plain_token.expose_secret().as_bytes(), &salt)
    .map_err(|_| ServerError::DbError)?
    .to_string();

    Ok(SecretString::from(token_hash))
}

/// Store a validation token in the DB.
#[tracing::instrument(skip(transaction, token))]
pub async fn store_validation_token(
    transaction: &mut Transaction<'static, MySql>,
    token: &SecretString,
    expiry: TimeDelta,
    client_id: &ClientId,
) -> Result<(), ServerError> {
    let query = sqlx::query!(
        r#"
        INSERT INTO ApiToken
        (created, api_token, valid_until, client_id)
        VALUES(current_timestamp(), ?, ?, ?);
        "#,
        token.expose_secret(),
        Local::now() + expiry,
        client_id.to_string(),
    );

    transaction.execute(query).await.map_err(|e| {
        error!("{e}");
        ServerError::DbError
    })?;

    Ok(())
}

/// Remove a token that won't be used anymore from the `ApiToken` table.
///
/// Every row whose `api_token` column equals `token` is deleted. The number of affected rows is not checked, so
/// deleting a token that isn't stored is not an error.
///
/// # Errors
///
/// Returns [`ServerError::DbError`] (after logging the cause) if the statement fails.
#[tracing::instrument(skip(pool, token))]
pub async fn delete_token(pool: &MySqlPool, token: SecretString) -> Result<(), ServerError> {
    let statement = sqlx::query!(
        "DELETE FROM ApiToken WHERE api_token = ?",
        token.expose_secret()
    );

    match pool.execute(statement).await {
        Ok(_) => Ok(()),
        Err(e) => {
            error!("{e}");
            Err(ServerError::DbError)
        }
    }
}

/// Check if the client hash access to the restricted API's endpoints.
///
/// # Description
///
/// Given a client access token, the stored hash of the token is retrieved from the database and compared. If the
/// comparison is positive, it is checked if the client is enabled.
pub async fn check_access(pool: &MySqlPool, token: &SecretString) -> Result<(), Box<dyn Error>> {
    // Let's split the token to get the client's ID and the token itself.
    let token_split = token.expose_secret().split(':').collect::<Vec<&str>>();
    let client_id = token_split[0];
    let token = SecretString::from(token_split[1]);
    // First, retrieve the credentials for the client using the email.
    let query = sqlx::query!(
        r#"
        SELECT at.api_token, at.valid_until, au.enabled
        FROM ApiUser au natural join ApiToken at
        WHERE au.id = ?
        "#,
        client_id.to_string()
    )
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        error!("{e}");
        Box::new(ServerError::DbError)
    })?;

    let (token_saved, valid_until, enabled) = match query {
        Some(record) => (
            SecretString::from(record.api_token),
            record.valid_until,
            record.enabled,
        ),
        None => {
            info!("The given client ID ({client_id}) does not exist in the DB");
            return Err(Box::new(DataDomainError::InvalidId));
        }
    };

    debug!(
        "The client exists in the DB. Proceeding to compare the given token with the stored hash"
    );

    // First, check if the given pair client-token matches the saved one. This avoid giving information about disabled
    // accounts or expired tokens to people that has no access to the API.
    verify_token(token_saved, token).map_err(Box::new)?;
    debug!("The token is valid and registered to the client");

    // Second, check if the account is actually enabled.
    if enabled.unwrap_or_default() > 0 {
        debug!("The client's account is enabled");
        // Finally, check that the token is not expired.
        if valid_until.date_naive() - Local::now().date_naive() < TimeDelta::zero() {
            debug!("The client's token is expired");
            Err(Box::new(DataDomainError::ExpiredAccess))
        } else {
            debug!("The token is valid and not expired");
            Ok(())
        }
    } else {
        debug!("The account is disabled");
        Err(Box::new(DataDomainError::AccountDisabled))
    }
}

/// Mark the API client identified by `client_id` as enabled.
///
/// Sets `enabled = TRUE` on the matching `ApiUser` row. The number of affected rows is not checked.
///
/// # Errors
///
/// Returns [`ServerError::DbError`] (after logging the cause) if the update fails.
#[tracing::instrument(skip(pool))]
pub async fn enable_client(pool: &MySqlPool, client_id: &ClientId) -> Result<(), ServerError> {
    let update = sqlx::query!(
        r#"
    UPDATE ApiUser SET enabled = TRUE
    WHERE id = ?;
    "#,
        client_id.to_string()
    );

    if let Err(e) = pool.execute(update).await {
        error!("{e}");
        return Err(ServerError::DbError);
    }

    Ok(())
}

/// Check if the user attempted to or is registered already in the DB.
#[tracing::instrument(skip(pool))]
pub async fn check_existing_user(
    pool: &MySqlPool,
    email: &str,
) -> Result<ClientId, Box<dyn Error>> {
    let existing_id = sqlx::query!("SELECT id FROM ApiUser WHERE email = ?", email)
        .fetch_optional(pool)
        .await
        .map_err(|e| {
            error!("{e}");
            ServerError::DbError
        })?;

    match existing_id {
        Some(record) => Ok(ClientId::from_str(&record.id).unwrap()),
        None => Err(Box::new(DataDomainError::InvalidEmail)),
    }
}

/// Mark the client's account identified by `id` as validated.
///
/// Sets `validated = TRUE` on the matching `ApiUser` row as part of the given transaction. The number of affected
/// rows is not checked.
///
/// # Errors
///
/// Returns [`ServerError::DbError`] (after logging the cause) if the update fails.
#[tracing::instrument(skip(transaction))]
pub async fn validate_client_account(
    transaction: &mut Transaction<'static, MySql>,
    id: &ClientId,
) -> Result<(), ServerError> {
    let update = sqlx::query!(
        r#"
        UPDATE ApiUser
        SET validated = TRUE
        WHERE id = ?;
        "#,
        id.to_string()
    );

    match transaction.execute(update).await {
        Ok(_) => Ok(()),
        Err(e) => {
            error!("Error found while updating ApiUser's entry for {id}: {e}");
            Err(ServerError::DbError)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;
    use secrecy::SecretString;

    #[rstest]
    fn equal_token_hash_match() {
        let token = SecretString::from(generate_token());
        let token_hash =
            generate_new_token_hash(token.clone()).expect("Failed to generate token hash");
        assert!(verify_token(token_hash, token).is_ok())
    }

    #[rstest]
    fn different_token_hash_does_not_match() {
        let token = SecretString::from(generate_token());
        let token_hash = generate_new_token_hash(token).expect("Failed to generate token hash");
        // A client sends a plain token, so verify the stored hash against a *different plain token*.
        let other_token = SecretString::from(generate_token());
        assert!(verify_token(token_hash, other_token).is_err())
    }

    #[rstest]
    fn malformed_stored_hash_is_rejected() {
        let token = SecretString::from(generate_token());
        assert!(verify_token(SecretString::from("not a PHC hash"), token).is_err())
    }
}