1
use thiserror::Error;
2
use ustr::Ustr;
3

            
4
#[derive(Clone, Debug, PartialEq, Eq, Error)]
5
pub enum CombinatoricsError {
6
    #[error("The operation is not defined for the given input: {0}")]
7
    NotDefined(Ustr),
8
    #[error("The result is too large to fit into the return type")]
9
    Overflow,
10
}
11

            
12
impl CombinatoricsError {
13
    pub fn not_defined(input: impl Into<Ustr>) -> Self {
14
        Self::NotDefined(input.into())
15
    }
16
}
17

            
18
/// Count *combinations* - the number of ways to pick `n_choose` items from `n_total`,
19
/// where order does not matter.
20
///
21
/// # Formula
22
/// C(n, r) = n! / (r! * (n-r)!)
23
///
24
/// Not defined for r > n.
25
pub fn count_combinations(n_total: u64, n_choose: u64) -> Result<u64, CombinatoricsError> {
26
    if n_choose > n_total {
27
        return Err(CombinatoricsError::not_defined(
28
            "n_choose must be <= n_total",
29
        ));
30
    }
31

            
32
    // Use symmetry C(n, k) == C(n, n-k) to make the loop smaller
33
    let k = n_choose.min(n_total - n_choose);
34

            
35
    // Repeatedly multiply / divide as factors get big fast;
36
    // return None if we overflow anyway
37
    (1u64..=k).try_fold(1u64, |acc, val| {
38
        n_total
39
            .checked_sub(val)
40
            .ok_or(CombinatoricsError::Overflow)? // n_total - val
41
            .checked_add(1u64)
42
            .ok_or(CombinatoricsError::Overflow)? // + 1
43
            .checked_mul(acc)
44
            .ok_or(CombinatoricsError::Overflow)? // * acc
45
            .checked_div(val)
46
            .ok_or(CombinatoricsError::Overflow) // / val
47
    })
48
}
49

            
50
/// Count *permutations* - the number of ways to pick `n_choose` items from `n_total`,
51
/// where order matters.
52
///
53
/// # Formula
54
/// P(n, r) = n! / (n-r)!
55
///
56
/// Not defined for r > n.
57
#[allow(dead_code)]
58
pub fn count_permutations(n_total: u64, n_choose: u64) -> Result<u64, CombinatoricsError> {
59
    if n_choose > n_total {
60
        return Err(CombinatoricsError::not_defined(
61
            "n_choose must be <= n_total",
62
        ));
63
    }
64

            
65
    let start = n_total
66
        .checked_sub(n_choose)
67
        .ok_or(CombinatoricsError::Overflow)?
68
        .checked_add(1u64)
69
        .ok_or(CombinatoricsError::Overflow)?;
70
    (start..=n_total).try_fold(1u64, |acc, val| {
71
        acc.checked_mul(val).ok_or(CombinatoricsError::Overflow)
72
    })
73
}