1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
| #include <cassert> #include <cmath> #include <iostream> #include <random> #include <algorithm> #include <limits> #include <functional> #include <map> #include <vector>
/// A discrete random variable over a finite list of outcomes.
///
/// Constructed from a list of values and a parallel list of probabilities;
/// each call to operator() draws one value, where values_[i] is selected with
/// probability prob[i].
///
/// Sampling is done by inverse-CDF lookup: a uniform draw u in [0, 1) is
/// located in the table of cumulative sums by binary search.
///
/// operator() is const but advances the internal (mutable) generator, so
/// concurrent calls on the same object are not synchronized.
template <typename T>
class discrete_random_variable {
 private:
  /// Allowed absolute deviation of the probability total from 1. Summing
  /// doubles accumulates rounding error, so the total of a perfectly valid
  /// distribution may differ from 1 by more than machine epsilon.
  static constexpr double kSumTolerance = 1e-9;

  const std::vector<T> values_;
  /// cumulative_[0] == 0 and cumulative_[i + 1] == prob[0] + ... + prob[i].
  const std::vector<double> cumulative_;
  /// Seeded once from a temporary std::random_device; the device itself is
  /// not kept as a member.
  mutable std::mt19937 gen_{std::random_device{}()};
  mutable std::uniform_real_distribution<double> dis_{0.0, 1.0};

 public:
  /// @param val   the possible outcomes; must not be empty.
  /// @param prob  probability of each outcome; must have the same size as
  ///              `val`, contain no negative entries, and sum to 1 (within
  ///              kSumTolerance).
  ///
  /// Preconditions are checked with assert, i.e. only when NDEBUG is not
  /// defined.
  discrete_random_variable(const std::vector<T>& val, const std::vector<double>& prob)
      : values_(val), cumulative_(generate_cumulative(prob)) {
    assert(!val.empty());
    assert(val.size() == prob.size());
    // Negative (or NaN) entries would make the cumulative table non-monotonic
    // and break the binary search in operator().
    assert(std::all_of(prob.begin(), prob.end(), [](const double p) { return p >= 0.0; }));
    assert(std::fabs(1.0 - cumulative_.back()) < kSumTolerance);
  }

  /// Draws one outcome according to the configured probabilities.
  ///
  /// @return a copy of the selected element of the value list.
  T operator()() const {
    const double u = dis_(gen_);

    // Skip the leading 0 so that the offset of the first cumulative sum that
    // is strictly greater than u is directly the index of the outcome.
    const auto first_sum = cumulative_.begin() + 1;
    const auto upper = std::upper_bound(first_sum, cumulative_.end(), u);
    std::size_t idx = static_cast<std::size_t>(upper - first_sum);

    // If the rounded total is slightly below 1, u can land at or beyond the
    // last cumulative sum; attribute such draws to the last outcome instead
    // of indexing past the end.
    if (idx >= values_.size()) {
      idx = values_.size() - 1;
    }
    return values_[idx];
  }

 private:
  /// Builds the cumulative table {0, p0, p0 + p1, ...} for `prob`.
  ///
  /// Static because it runs from the member-initializer list, before the
  /// object is fully constructed, and needs no instance state.
  static std::vector<double> generate_cumulative(const std::vector<double>& prob) {
    std::vector<double> cumulative;
    cumulative.reserve(prob.size() + 1);

    double running = 0.0;
    cumulative.push_back(running);
    for (const double p : prob) {
      running += p;
      cumulative.push_back(running);
    }
    return cumulative;
  }
};
/// Demo driver: samples a four-outcome distribution 400000 times, checks that
/// every sample is one of the configured values, and prints a text histogram
/// with one '=' per 2500 occurrences of each value.
int main() {
  const std::vector<int> values{1, 2, 3, 4};
  const std::vector<double> probs{0.05, 0.25, 0.35, 0.35};

  discrete_random_variable<int> drv{values, probs};

  // Occurrence count per sampled value, ordered by value.
  std::map<int, size_t> counter;

  for (size_t remaining = 400000; remaining != 0; --remaining) {
    const int sample = drv();
    assert(std::find(values.begin(), values.end(), sample) != values.end());
    ++counter[sample];
  }

  for (auto it = counter.begin(); it != counter.end(); ++it) {
    std::cout << it->first << "[" << it->second << "]" << ": \t";
    // One bar segment for every full 2500 occurrences.
    for (size_t bars = it->second / 2500; bars != 0; --bars) {
      std::cout << '=';
    }
    std::cout << std::endl;
  }

  return 0;
}
|