Browse Source
Add discard/keep modifier to dice roll, improve dice roll tests
archive/feature/daily-titles-log
Add discard/keep modifier to dice roll, improve dice roll tests
archive/feature/daily-titles-log
5 changed files with 362 additions and 124 deletions
-
5devpotato_bot/commands/roll.py
-
93devpotato_bot/dice_parser.py
-
47devpotato_bot/quickselect.py
-
257tests/test_dice_parser.py
-
84tests/test_quickselect.py
@ -0,0 +1,47 @@ |
|||
import operator |
|||
import random |
|||
|
|||
|
|||
def partition(items, left, right, pivot_index, compare): |
|||
"""Groups items between left and right index into two parts: |
|||
those less than element at pivot_index, and those greater than or equal to it |
|||
""" |
|||
assert 0 <= left <= pivot_index <= right < len(items) |
|||
if left == right: |
|||
return pivot_index |
|||
|
|||
pivot_value = items[pivot_index] |
|||
items[right], items[pivot_index] = items[pivot_index], items[right] |
|||
pivot_index = left |
|||
for i in range(left, right): |
|||
if compare(items[i], pivot_value): |
|||
items[pivot_index], items[i] = items[i], items[pivot_index] |
|||
pivot_index += 1 |
|||
items[right], items[pivot_index] = items[pivot_index], items[right] |
|||
return pivot_index |
|||
|
|||
|
|||
def select(items, k, compare=operator.lt): |
|||
"""Reorders items in-place so that first k items are the lowest (largest) ones |
|||
|
|||
A custom compare function can be supplied to customize comparison of the items. |
|||
The function uses quickselect algorithm with minor modifications.""" |
|||
if not items or k == 0 or k == len(items): |
|||
return |
|||
if k < 0: |
|||
raise ValueError('k should not be less than zero') |
|||
if k > len(items): |
|||
raise ValueError('k should not be greater than item count') |
|||
left = 0 |
|||
right = len(items) - 1 |
|||
while True: |
|||
if left == right: |
|||
break |
|||
pivot_index = random.randint(left, right) |
|||
pivot_index = partition(items, left, right, pivot_index, compare) |
|||
if (k - 1) == pivot_index: |
|||
break |
|||
elif (k - 1) < pivot_index: |
|||
right = pivot_index - 1 |
|||
else: |
|||
left = pivot_index + 1 |
@ -1,147 +1,188 @@ |
|||
import itertools |
|||
import unittest |
|||
from unittest import mock |
|||
|
|||
from devpotato_bot.dice_parser import Dice, ParseError, ValueRangeError |
|||
from devpotato_bot.dice_parser import Dice, ParseError, ValueRangeError, ResultsKeepStrategy, RollResult |
|||
|
|||
|
|||
class DiceParserTest(unittest.TestCase): |
|||
@staticmethod |
|||
def _DiceVars(dice): |
|||
return dice.rolls, dice.sides, dice.modifier |
|||
|
|||
def test_basic_notation(self): |
|||
for roll_str, expected in ( |
|||
('1d6', (1, 6, 0)), |
|||
('d6', (1, 6, 0)), |
|||
('d5', (1, 5, 0)), |
|||
('5d1', (5, 1, 0)) |
|||
): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected) |
|||
class ResultsKeepStrategyMock: |
|||
def __init__(self, total, discard): |
|||
self._expected_total = total |
|||
self._discarded_default = discard |
|||
|
|||
for roll_str in ('d', '1d', '-1d', '-1d6'): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertRaises(ParseError, Dice.parse, roll_str) |
|||
def get_discarded_default(self, _): |
|||
return self._discarded_default |
|||
|
|||
def apply(self, results) -> int: |
|||
if ResultsKeepStrategyMock.DEFAULT is self: |
|||
print('apply') |
|||
# ensure that actual strategy gets own copy of items |
|||
results.clear() |
|||
return self._expected_total |
|||
|
|||
for roll_str in ('0d6', '1d0', 'd0', '0d0'): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertRaises(ValueRangeError, Dice.parse, roll_str) |
|||
|
|||
def test_modifier(self): |
|||
for roll_str, expected in ( |
|||
('d6-5', (1, 6, -5)), |
|||
('d6+5', (1, 6, 5)), |
|||
('d6+0', (1, 6, 0)), |
|||
('2d10-1', (2, 10, -1)) |
|||
): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected) |
|||
|
|||
for roll_str in ('d+6', '1d-6', '-6d-1', 'd6+', 'd6-', 'd6-+6', 'd6+-6', 'd6-+-6', 'd6+d6'): |
|||
ResultsKeepStrategyMock.DEFAULT = ResultsKeepStrategyMock(-1, object()) |
|||
|
|||
|
|||
class DiceTest(unittest.TestCase): |
|||
def test_invalid_patterns(self): |
|||
for roll_str in ( |
|||
'd', '1d', '-1d', '-1d6', 'd6+', 'd6-', |
|||
'd+6', '1d-6', '-6d-1', 'd6-+6', 'd6+-6', 'd6-+-6', 'd6+d6', |
|||
'%d', '1%d', '-%1d', '-1%d6', '-1d%', 'd%%', '2d%%', |
|||
'd6L', 'd6H', 'dH1', 'dL1', 'd6L-1', 'd6H-1' |
|||
): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertRaises(ParseError, Dice.parse, roll_str) |
|||
|
|||
def test_percentile_dice(self): |
|||
for roll_str, expected in ( |
|||
('d%', (1, 100, 0)), |
|||
('d%-5', (1, 100, -5)), |
|||
('d%+5', (1, 100, 5)), |
|||
('d%+0', (1, 100, 0)), |
|||
('2d%-1', (2, 100, -1)), |
|||
('2d%+1', (2, 100, 1)) |
|||
): |
|||
for roll_str in ('0d6', '1d0', 'd0', '0d0', '0d%'): |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected) |
|||
self.assertRaises(ValueRangeError, Dice.parse, roll_str) |
|||
|
|||
for roll_str in ('%d', '1%d', '-%1d', '-1%d6', '-1d%', 'd%%', '2d%%'): |
|||
def test_valid_patterns(self): |
|||
roll_counts = [('', 1), ('1', 1), ('2', 2), ('10', 10), ('05', 5)] |
|||
side_counts = [ |
|||
('d6', 6), ('d20', 20), |
|||
('d%', 100), ('d100', 100), |
|||
('d1', 1), ('d2', 2), ('d03', 3) |
|||
] |
|||
modifiers = [ |
|||
('', 0), ('+0', 0), |
|||
('+1', 1), ('-1', -1), |
|||
('+05', 5), ('-05', -5) |
|||
] |
|||
discards = itertools.chain( |
|||
[('', (0, False, False))], |
|||
( |
|||
(f'{prefix}{t}{count}', (count, prefix != '-', t == 'L')) |
|||
for prefix, t, count |
|||
in itertools.product(['', '+', '-'], 'LH', [0, 1, 2, 10, 20]) |
|||
) |
|||
) |
|||
for i in itertools.product(roll_counts, side_counts, discards, modifiers): |
|||
roll_str, expected_vars = zip(*i) |
|||
roll_str = ''.join(roll_str) |
|||
with self.subTest(roll_str=roll_str): |
|||
self.assertRaises(ParseError, Dice.parse, roll_str) |
|||
rolls, sides, discard_vars, modifier = expected_vars |
|||
d = Dice.parse(roll_str) |
|||
self.assertEqual(d.sides, sides) |
|||
self.assertEqual(d.rolls, rolls) |
|||
self.assertEqual(d.modifier, modifier) |
|||
|
|||
self.assertRaises(ValueRangeError, Dice.parse, '0d%') |
|||
discard_count, discard_keep, discard_lowest = discard_vars |
|||
self.assertEqual(d.discard_strategy.count, min(discard_count, rolls)) |
|||
self.assertEqual(d.discard_strategy.keep, discard_keep) |
|||
self.assertEqual(d.discard_strategy.lowest, discard_lowest) |
|||
|
|||
def test_init(self): |
|||
for rolls, sides, modifier in ( |
|||
(1, 1, 0), |
|||
(1, 6, 0), |
|||
(1, 10, 0), |
|||
(100, 6, 0), |
|||
(100, 1, 0), |
|||
(1, 120, 0), |
|||
(100, 120, 0) |
|||
): |
|||
with self.subTest(rolls=rolls, sides=sides): |
|||
d = Dice(rolls, sides) |
|||
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier)) |
|||
|
|||
with self.subTest(rolls=rolls, sides=sides, modifier=None): |
|||
d = Dice(rolls, sides, None) |
|||
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier)) |
|||
|
|||
with self.subTest(rolls=rolls, sides=sides, modifier=modifier): |
|||
d = Dice(rolls, sides, modifier) |
|||
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier)) |
|||
|
|||
for rolls, sides, modifier in ( |
|||
(1, 1, -1), |
|||
(1, 1, 1), |
|||
(1, 6, 6), |
|||
(1, 6, -6) |
|||
): |
|||
with self.subTest(rolls=rolls, sides=sides, modifier=modifier): |
|||
d = Dice(rolls, sides, modifier) |
|||
self.assertEqual(d.rolls, rolls) |
|||
self.assertEqual(d.sides, sides) |
|||
self.assertEqual(Dice(1, 6).modifier, 0) |
|||
self.assertEqual(Dice(1, 6, modifier=None).modifier, 0) |
|||
self.assertEqual(Dice(1, 6).discard_strategy, ResultsKeepStrategy.DEFAULT) |
|||
|
|||
for rolls, sides in ( |
|||
(0, 0), |
|||
(0, 6), |
|||
(1, 0), |
|||
|
|||
(-1, 6), |
|||
(-1, -6), |
|||
(1, -6), |
|||
|
|||
(1, 121), |
|||
(100, 121), |
|||
(120, 120), |
|||
(120, 6), |
|||
(120, 100) |
|||
): |
|||
invalid_roll_counts = itertools.product( |
|||
(-1, 0, Dice.ROLL_LIMIT + 1), |
|||
(-1, 0, 6, 100, Dice.BIGGEST_DICE, Dice.BIGGEST_DICE + 1) |
|||
) |
|||
invalid_sides = itertools.product( |
|||
(1, 10, Dice.ROLL_LIMIT), |
|||
(-1, 0, Dice.BIGGEST_DICE + 1) |
|||
) |
|||
for rolls, sides in itertools.chain(invalid_roll_counts, invalid_sides): |
|||
with self.subTest(rolls=rolls, sides=sides): |
|||
self.assertRaises(ValueRangeError, Dice, rolls, sides) |
|||
|
|||
@mock.patch('random.randint', return_value=1) |
|||
@mock.patch('random.randint') |
|||
def test_get_result(self, randint): |
|||
for roll_count, sides, modifier in ( |
|||
(1, 2, 0), (3, 6, 0), (3, 6, -1), (3, 6, +1), (5, 10, -10), (5, 10, +10), |
|||
(1, 100, 0), (5, 100, -5), (10, 100, 0), (2, 100, 10) |
|||
discard_strategy = ResultsKeepStrategyMock(1337, discard=object()) |
|||
for (roll_count, sides), modifier in itertools.product( |
|||
[(1, 1), (1, 6), (10, 20)], |
|||
[0, -1, +1] |
|||
): |
|||
with self.subTest(rolls=roll_count, sides=sides, modifier=modifier): |
|||
d = Dice(roll_count, sides, modifier) |
|||
d = Dice(roll_count, sides, modifier=modifier, discard_strategy=discard_strategy) |
|||
expected_calls = [mock.call(1, sides)] * roll_count |
|||
expected_total = roll_count + modifier |
|||
expected_rolls = [1] * roll_count |
|||
for item_limit in ( |
|||
None, 0, 1, roll_count - 1, roll_count, roll_count + 1, roll_count + 20 |
|||
None, 0, 1, roll_count - 1, roll_count, roll_count + 1, roll_count + 20 |
|||
): |
|||
randint.side_effect = range(roll_count) |
|||
with self.subTest(item_limit=item_limit): |
|||
roll_total, single_rolls, was_limited = d.get_result(item_limit) |
|||
roll_total, single_rolls, was_limited = d.get_results(item_limit) |
|||
randint.assert_has_calls(expected_calls) |
|||
self.assertEqual(randint.call_count, roll_count) |
|||
self.assertEqual(roll_total, expected_total) |
|||
|
|||
limit_is_set = item_limit is not None and roll_count > item_limit |
|||
self.assertEqual(was_limited, limit_is_set) |
|||
roll_subset = expected_rolls[:(item_limit if limit_is_set else roll_count)] |
|||
self.assertSequenceEqual(single_rolls, roll_subset) |
|||
self.assertEqual(roll_total, discard_strategy._expected_total + modifier) |
|||
|
|||
should_limit = item_limit is not None and roll_count > item_limit |
|||
self.assertEqual(was_limited, should_limit) |
|||
|
|||
expected_rolls = range(item_limit if should_limit else roll_count) |
|||
self.assertTrue(all( |
|||
r.value == e |
|||
for r, e in itertools.zip_longest(single_rolls, expected_rolls) |
|||
)) |
|||
self.assertTrue(all( |
|||
r.is_discarded is discard_strategy._discarded_default |
|||
for r in single_rolls |
|||
)) |
|||
randint.reset_mock() |
|||
|
|||
|
|||
for item_limit in (-1, -roll_count, -roll_count-1): |
|||
with self.subTest(item_limit=item_limit): |
|||
self.assertRaises(ValueError, d.get_result, item_limit) |
|||
self.assertRaises(ValueError, d.get_results, item_limit) |
|||
randint.assert_not_called() |
|||
randint.reset_mock() |
|||
|
|||
|
|||
class ResultsKeepStrategyTest(unittest.TestCase): |
|||
@classmethod |
|||
def setUpClass(cls): |
|||
ordered = tuple(range(1, 11)) |
|||
shuffled = list(ordered) |
|||
import random |
|||
random.seed(1024) |
|||
random.shuffle(shuffled) |
|||
random_items = tuple(random.randint(1, 10) for _ in range(10)) |
|||
cls.items_pool = ( |
|||
('reversed', tuple(reversed(ordered)), ordered), |
|||
('shuffled', tuple(shuffled), ordered), |
|||
('random', random_items, tuple(sorted(random_items))) |
|||
) |
|||
|
|||
def test_init(self): |
|||
for count, keep, use_lowest in itertools.product( |
|||
(-1, -2), (True, False), (True, False) |
|||
): |
|||
self.assertRaises(ValueError, ResultsKeepStrategy, count=count, keep=keep, lowest=use_lowest) |
|||
|
|||
def test_apply(self): |
|||
for item_set_name, items, ordered in self.items_pool: |
|||
item_count = len(items) |
|||
for discard_count, keep, use_lowest in itertools.product( |
|||
range(0, item_count+2), |
|||
(True, False), |
|||
(True, False) |
|||
): |
|||
with self.subTest(item_set=item_set_name, discard_count=discard_count, keep=keep, use_lowest=use_lowest): |
|||
s = ResultsKeepStrategy(discard_count, keep=keep, lowest=use_lowest) |
|||
discard_by_default = s.get_discarded_default(item_count) |
|||
self.assertEqual(discard_by_default, (discard_count >= item_count//2) != keep) |
|||
|
|||
roll_results = [RollResult(i, discard_by_default) for i in items] |
|||
total = s.apply(roll_results) |
|||
roll_results.sort(key=lambda r: (r.value, r.is_discarded == (keep == use_lowest))) |
|||
|
|||
self.assertTrue(all( |
|||
r.value == i for r, i in itertools.zip_longest(roll_results, ordered) |
|||
)) |
|||
|
|||
keep_lowest = (keep == use_lowest) |
|||
lowest_count = discard_count if use_lowest else max(0, item_count-discard_count) |
|||
kept_slice = ((0, lowest_count) if keep_lowest else (lowest_count, None)) |
|||
self.assertEqual(total, sum(itertools.islice(ordered, *kept_slice))) |
|||
self.assertTrue(all( |
|||
r.is_discarded == ((i < lowest_count) != keep_lowest) |
|||
for i, r in enumerate(roll_results) |
|||
)) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
@ -0,0 +1,84 @@ |
|||
import itertools |
|||
import random |
|||
import unittest |
|||
|
|||
from devpotato_bot.quickselect import partition, select |
|||
|
|||
|
|||
class ItemSelectionTest(unittest.TestCase): |
|||
@classmethod |
|||
def setUpClass(cls): |
|||
ordered = tuple(range(1, 11)) |
|||
shuffled = list(ordered) |
|||
random.seed(1024) |
|||
random.shuffle(shuffled) |
|||
random_items = tuple(random.randint(1, 10) for _ in range(10)) |
|||
cls.items_pool = ( |
|||
('reversed', tuple(reversed(ordered)), ordered), |
|||
('shuffled', tuple(shuffled), ordered), |
|||
('random', random_items, tuple(sorted(random_items))) |
|||
) |
|||
|
|||
def setUp(self) -> None: |
|||
random.seed(1024) |
|||
|
|||
def test_select_noop(self): |
|||
for n in [0, 1, 10]: |
|||
original_items = list(range(n)) |
|||
items = list(original_items) |
|||
select(items, 0) |
|||
self.assertEqual(original_items, items) |
|||
|
|||
items = list(original_items) |
|||
select(items, n) |
|||
self.assertEqual(original_items, items) |
|||
|
|||
self.assertRaises(ValueError, select, list([1, 2]), -1) |
|||
self.assertRaises(ValueError, select, list([1, 2]), 3) |
|||
|
|||
def test_select_sorted(self): |
|||
items = list(range(10)) |
|||
select(items, 4) |
|||
self.assertSequenceEqual(range(10), items) |
|||
|
|||
items = list(range(10)[::-1]) |
|||
select(items, 4, compare=int.__gt__) |
|||
self.assertSequenceEqual(range(10)[::-1], items) |
|||
|
|||
def test_select(self): |
|||
for item_set_name, items, ordered in self.items_pool: |
|||
for k in range(1, len(items)): |
|||
with self.subTest(item_set=item_set_name, k=k): |
|||
selected = list(items) |
|||
select(selected, k) |
|||
self.assertEqual(ordered[k-1], selected[k-1]) |
|||
self.assertSequenceEqual(ordered, sorted(selected[:k]) + sorted(selected[k:])) |
|||
|
|||
@staticmethod |
|||
def get_pivot_combinations(items): |
|||
n = len(items) |
|||
for left in range(n - 1): |
|||
for right in range(left, n): |
|||
for pivot_i in range(left, right + 1): |
|||
yield left, right, pivot_i |
|||
|
|||
def test_partition(self): |
|||
for item_set_name, original_items, _ in self.items_pool: |
|||
for left, right, pivot_i in self.get_pivot_combinations(original_items): |
|||
with self.subTest(item_set=item_set_name, left=left, right=right, pivot_i=pivot_i): |
|||
items = list(original_items) |
|||
new_pivot_i = partition(items, left, right, pivot_i, int.__lt__) |
|||
|
|||
# check ordering within [left; right] |
|||
pivot = original_items[pivot_i] |
|||
self.assertEqual(pivot, items[new_pivot_i]) |
|||
self.assertTrue(all(x < pivot for x in items[left:new_pivot_i])) |
|||
self.assertTrue(all(x >= pivot for x in items[new_pivot_i+1:right+1])) |
|||
|
|||
# order of items outside of [left; right] range is unchanged |
|||
self.assertSequenceEqual(items[:left], original_items[:left]) |
|||
self.assertSequenceEqual(items[right+1:], original_items[right+1:]) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
Write
Preview
Loading…
Cancel
Save
Reference in new issue