Browse Source

Add discard/keep modifier to dice roll, improve dice roll tests

archive/feature/daily-titles-log
Vladislav Glinsky 6 years ago
parent
commit
e840fe1fb6
Signed by: cl0ne GPG Key ID: 9D058DD29491782E
  1. 5
      devpotato_bot/commands/roll.py
  2. 93
      devpotato_bot/dice_parser.py
  3. 47
      devpotato_bot/quickselect.py
  4. 257
      tests/test_dice_parser.py
  5. 84
      tests/test_quickselect.py

5
devpotato_bot/commands/roll.py

@ -69,7 +69,10 @@ def command_callback(update: Update, context: CallbackContext):
roll_total, single_rolls, was_limited = dice.get_result(item_limit=10)
lines.extend((
'\\(',
' \\+ '.join(str(r) for r in single_rolls)
' \\+ '.join(
f'~{r.value}~' if r.is_discarded else f'*__{r.value}__*'
for r in single_rolls
)
))
if was_limited:
lines.append(' \\+ ⋯ ')

93
devpotato_bot/dice_parser.py

@ -1,7 +1,10 @@
import itertools
import random
import re
from typing import List, Tuple, Optional
from devpotato_bot.quickselect import select
class ParseError(Exception):
def __init__(self):
@ -20,17 +23,69 @@ class ValueRangeError(ValueError):
})
class RollResult:
def __init__(self, value, is_discarded=False):
self.value = value
self.is_discarded = is_discarded
class ResultsKeepStrategy:
def __init__(self, count, *, keep, lowest):
if count < 0:
raise ValueError('count should be non-negative')
self.count = count
self.keep = keep
self.lowest = lowest
def get_discarded_default(self, roll_count):
# self.keep if self.count < roll_count/2 else not self.keep
return self.keep == (self.count < roll_count - self.count)
@staticmethod
def _compare_lowest(a: RollResult, b: RollResult):
return a.value < b.value
@staticmethod
def _compare_highest(a: RollResult, b: RollResult):
return a.value > b.value
def apply(self, results: List[RollResult]) -> int:
remaining_items = max(0, len(results) - self.count)
# Always select lower part
if self.count < remaining_items:
count = self.count
discard_selected = not self.keep
select_lowest = self.lowest
else:
count = remaining_items
discard_selected = self.keep
select_lowest = not self.lowest
comparator = self._compare_lowest if select_lowest else self._compare_highest
select(results, count, compare=comparator)
for r in itertools.islice(results, count):
r.is_discarded = discard_selected
kept_range = (count, None) if discard_selected else (0, count)
return sum(r.value for r in itertools.islice(results, *kept_range))
# default strategy is "discard none"
ResultsKeepStrategy.DEFAULT = ResultsKeepStrategy(0, keep=False, lowest=False)
class Dice:
__regex = re.compile(
r'(?P<rolls>\d+)?'
r'd'
r'(?P<sides>\d+|%)'
r'(?:(?P<keep_or_discard>[+-])?(?P<discard_type>[LH])(?P<discard_count>\d+))?'
r'(?:(?P<modifier_sign>[+-])(?P<modifier>\d+))?'
)
BIGGEST_DICE = 120
ROLL_LIMIT = 100
def __init__(self, rolls, sides, modifier=0):
def __init__(self, rolls, sides, *,
discard_strategy: ResultsKeepStrategy = ResultsKeepStrategy.DEFAULT,
modifier=0):
if not(0 < rolls <= self.ROLL_LIMIT):
raise ValueRangeError('roll count', rolls, (1, self.ROLL_LIMIT))
if not(0 < sides <= self.BIGGEST_DICE):
@ -38,36 +93,44 @@ class Dice:
self.rolls = rolls
self.sides = sides
self.modifier = modifier or 0
self.discard_strategy = discard_strategy
@staticmethod
def parse(roll_str):
match = Dice.__regex.fullmatch(roll_str)
if not match:
raise ParseError
rolls, sides, modifier_sign, modifier = match.groups()
(
rolls, sides,
keep_or_discard, discard_type, discard_count,
modifier_sign, modifier
) = match.groups()
rolls = int(rolls) if rolls else 1
sides = int(sides) if sides != '%' else 100
discard_strategy = ResultsKeepStrategy.DEFAULT
if discard_count:
discard_count = min(rolls, int(discard_count))
use_lowest = discard_type == 'L'
keep = keep_or_discard is None or keep_or_discard == '+'
discard_strategy = ResultsKeepStrategy(discard_count, keep=keep, lowest=use_lowest)
modifier = int(modifier) if modifier else 0
if modifier and modifier_sign == '-':
modifier = -modifier
return Dice(rolls, sides, modifier)
return Dice(rolls, sides, discard_strategy=discard_strategy, modifier=modifier)
def _single_roll(self):
return random.randint(1, self.sides)
def get_result(self, item_limit=None) -> Tuple[int, List[int], bool]:
total = self.modifier
items = []
def get_results(self, item_limit=None) -> Tuple[int, List[RollResult], bool]:
if item_limit is None:
item_count = self.rolls
return_count = self.rolls
else:
item_count = min(self.rolls, item_limit)
if item_limit < 0:
raise ValueError('item limit should be non-negative!')
for i in range(item_count):
r = self._single_roll()
total += r
items.append(r)
if item_limit is not None:
total += sum(self._single_roll() for _ in range(item_limit, self.rolls))
return total, items, item_count < self.rolls
return_count = min(self.rolls, item_limit)
discarded_default = self.discard_strategy.get_discarded_default(self.rolls)
results = [RollResult(self._single_roll(), is_discarded=discarded_default)
for _ in range(self.rolls)]
return_results = results[:return_count]
total = self.discard_strategy.apply(results) + self.modifier
return total, return_results, return_count < self.rolls

47
devpotato_bot/quickselect.py

@ -0,0 +1,47 @@
import operator
import random
def partition(items, left, right, pivot_index, compare):
"""Groups items between left and right index into two parts:
those less than element at pivot_index, and those greater than or equal to it
"""
assert 0 <= left <= pivot_index <= right < len(items)
if left == right:
return pivot_index
pivot_value = items[pivot_index]
items[right], items[pivot_index] = items[pivot_index], items[right]
pivot_index = left
for i in range(left, right):
if compare(items[i], pivot_value):
items[pivot_index], items[i] = items[i], items[pivot_index]
pivot_index += 1
items[right], items[pivot_index] = items[pivot_index], items[right]
return pivot_index
def select(items, k, compare=operator.lt):
"""Reorders items in-place so that first k items are the lowest (largest) ones
A custom compare function can be supplied to customize comparison of the items.
The function uses quickselect algorithm with minor modifications."""
if not items or k == 0 or k == len(items):
return
if k < 0:
raise ValueError('k should not be less than zero')
if k > len(items):
raise ValueError('k should not be greater than item count')
left = 0
right = len(items) - 1
while True:
if left == right:
break
pivot_index = random.randint(left, right)
pivot_index = partition(items, left, right, pivot_index, compare)
if (k - 1) == pivot_index:
break
elif (k - 1) < pivot_index:
right = pivot_index - 1
else:
left = pivot_index + 1

257
tests/test_dice_parser.py

@ -1,147 +1,188 @@
import itertools
import unittest
from unittest import mock
from devpotato_bot.dice_parser import Dice, ParseError, ValueRangeError
from devpotato_bot.dice_parser import Dice, ParseError, ValueRangeError, ResultsKeepStrategy, RollResult
class DiceParserTest(unittest.TestCase):
@staticmethod
def _DiceVars(dice):
return dice.rolls, dice.sides, dice.modifier
def test_basic_notation(self):
for roll_str, expected in (
('1d6', (1, 6, 0)),
('d6', (1, 6, 0)),
('d5', (1, 5, 0)),
('5d1', (5, 1, 0))
):
with self.subTest(roll_str=roll_str):
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected)
class ResultsKeepStrategyMock:
def __init__(self, total, discard):
self._expected_total = total
self._discarded_default = discard
for roll_str in ('d', '1d', '-1d', '-1d6'):
with self.subTest(roll_str=roll_str):
self.assertRaises(ParseError, Dice.parse, roll_str)
def get_discarded_default(self, _):
return self._discarded_default
def apply(self, results) -> int:
if ResultsKeepStrategyMock.DEFAULT is self:
print('apply')
# ensure that actual strategy gets own copy of items
results.clear()
return self._expected_total
for roll_str in ('0d6', '1d0', 'd0', '0d0'):
with self.subTest(roll_str=roll_str):
self.assertRaises(ValueRangeError, Dice.parse, roll_str)
def test_modifier(self):
for roll_str, expected in (
('d6-5', (1, 6, -5)),
('d6+5', (1, 6, 5)),
('d6+0', (1, 6, 0)),
('2d10-1', (2, 10, -1))
):
with self.subTest(roll_str=roll_str):
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected)
for roll_str in ('d+6', '1d-6', '-6d-1', 'd6+', 'd6-', 'd6-+6', 'd6+-6', 'd6-+-6', 'd6+d6'):
ResultsKeepStrategyMock.DEFAULT = ResultsKeepStrategyMock(-1, object())
class DiceTest(unittest.TestCase):
def test_invalid_patterns(self):
for roll_str in (
'd', '1d', '-1d', '-1d6', 'd6+', 'd6-',
'd+6', '1d-6', '-6d-1', 'd6-+6', 'd6+-6', 'd6-+-6', 'd6+d6',
'%d', '1%d', '-%1d', '-1%d6', '-1d%', 'd%%', '2d%%',
'd6L', 'd6H', 'dH1', 'dL1', 'd6L-1', 'd6H-1'
):
with self.subTest(roll_str=roll_str):
self.assertRaises(ParseError, Dice.parse, roll_str)
def test_percentile_dice(self):
for roll_str, expected in (
('d%', (1, 100, 0)),
('d%-5', (1, 100, -5)),
('d%+5', (1, 100, 5)),
('d%+0', (1, 100, 0)),
('2d%-1', (2, 100, -1)),
('2d%+1', (2, 100, 1))
):
for roll_str in ('0d6', '1d0', 'd0', '0d0', '0d%'):
with self.subTest(roll_str=roll_str):
self.assertEqual(self._DiceVars(Dice.parse(roll_str)), expected)
self.assertRaises(ValueRangeError, Dice.parse, roll_str)
for roll_str in ('%d', '1%d', '-%1d', '-1%d6', '-1d%', 'd%%', '2d%%'):
def test_valid_patterns(self):
roll_counts = [('', 1), ('1', 1), ('2', 2), ('10', 10), ('05', 5)]
side_counts = [
('d6', 6), ('d20', 20),
('d%', 100), ('d100', 100),
('d1', 1), ('d2', 2), ('d03', 3)
]
modifiers = [
('', 0), ('+0', 0),
('+1', 1), ('-1', -1),
('+05', 5), ('-05', -5)
]
discards = itertools.chain(
[('', (0, False, False))],
(
(f'{prefix}{t}{count}', (count, prefix != '-', t == 'L'))
for prefix, t, count
in itertools.product(['', '+', '-'], 'LH', [0, 1, 2, 10, 20])
)
)
for i in itertools.product(roll_counts, side_counts, discards, modifiers):
roll_str, expected_vars = zip(*i)
roll_str = ''.join(roll_str)
with self.subTest(roll_str=roll_str):
self.assertRaises(ParseError, Dice.parse, roll_str)
rolls, sides, discard_vars, modifier = expected_vars
d = Dice.parse(roll_str)
self.assertEqual(d.sides, sides)
self.assertEqual(d.rolls, rolls)
self.assertEqual(d.modifier, modifier)
self.assertRaises(ValueRangeError, Dice.parse, '0d%')
discard_count, discard_keep, discard_lowest = discard_vars
self.assertEqual(d.discard_strategy.count, min(discard_count, rolls))
self.assertEqual(d.discard_strategy.keep, discard_keep)
self.assertEqual(d.discard_strategy.lowest, discard_lowest)
def test_init(self):
for rolls, sides, modifier in (
(1, 1, 0),
(1, 6, 0),
(1, 10, 0),
(100, 6, 0),
(100, 1, 0),
(1, 120, 0),
(100, 120, 0)
):
with self.subTest(rolls=rolls, sides=sides):
d = Dice(rolls, sides)
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier))
with self.subTest(rolls=rolls, sides=sides, modifier=None):
d = Dice(rolls, sides, None)
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier))
with self.subTest(rolls=rolls, sides=sides, modifier=modifier):
d = Dice(rolls, sides, modifier)
self.assertEqual(self._DiceVars(d), (rolls, sides, modifier))
for rolls, sides, modifier in (
(1, 1, -1),
(1, 1, 1),
(1, 6, 6),
(1, 6, -6)
):
with self.subTest(rolls=rolls, sides=sides, modifier=modifier):
d = Dice(rolls, sides, modifier)
self.assertEqual(d.rolls, rolls)
self.assertEqual(d.sides, sides)
self.assertEqual(Dice(1, 6).modifier, 0)
self.assertEqual(Dice(1, 6, modifier=None).modifier, 0)
self.assertEqual(Dice(1, 6).discard_strategy, ResultsKeepStrategy.DEFAULT)
for rolls, sides in (
(0, 0),
(0, 6),
(1, 0),
(-1, 6),
(-1, -6),
(1, -6),
(1, 121),
(100, 121),
(120, 120),
(120, 6),
(120, 100)
):
invalid_roll_counts = itertools.product(
(-1, 0, Dice.ROLL_LIMIT + 1),
(-1, 0, 6, 100, Dice.BIGGEST_DICE, Dice.BIGGEST_DICE + 1)
)
invalid_sides = itertools.product(
(1, 10, Dice.ROLL_LIMIT),
(-1, 0, Dice.BIGGEST_DICE + 1)
)
for rolls, sides in itertools.chain(invalid_roll_counts, invalid_sides):
with self.subTest(rolls=rolls, sides=sides):
self.assertRaises(ValueRangeError, Dice, rolls, sides)
@mock.patch('random.randint', return_value=1)
@mock.patch('random.randint')
def test_get_result(self, randint):
for roll_count, sides, modifier in (
(1, 2, 0), (3, 6, 0), (3, 6, -1), (3, 6, +1), (5, 10, -10), (5, 10, +10),
(1, 100, 0), (5, 100, -5), (10, 100, 0), (2, 100, 10)
discard_strategy = ResultsKeepStrategyMock(1337, discard=object())
for (roll_count, sides), modifier in itertools.product(
[(1, 1), (1, 6), (10, 20)],
[0, -1, +1]
):
with self.subTest(rolls=roll_count, sides=sides, modifier=modifier):
d = Dice(roll_count, sides, modifier)
d = Dice(roll_count, sides, modifier=modifier, discard_strategy=discard_strategy)
expected_calls = [mock.call(1, sides)] * roll_count
expected_total = roll_count + modifier
expected_rolls = [1] * roll_count
for item_limit in (
None, 0, 1, roll_count - 1, roll_count, roll_count + 1, roll_count + 20
None, 0, 1, roll_count - 1, roll_count, roll_count + 1, roll_count + 20
):
randint.side_effect = range(roll_count)
with self.subTest(item_limit=item_limit):
roll_total, single_rolls, was_limited = d.get_result(item_limit)
roll_total, single_rolls, was_limited = d.get_results(item_limit)
randint.assert_has_calls(expected_calls)
self.assertEqual(randint.call_count, roll_count)
self.assertEqual(roll_total, expected_total)
limit_is_set = item_limit is not None and roll_count > item_limit
self.assertEqual(was_limited, limit_is_set)
roll_subset = expected_rolls[:(item_limit if limit_is_set else roll_count)]
self.assertSequenceEqual(single_rolls, roll_subset)
self.assertEqual(roll_total, discard_strategy._expected_total + modifier)
should_limit = item_limit is not None and roll_count > item_limit
self.assertEqual(was_limited, should_limit)
expected_rolls = range(item_limit if should_limit else roll_count)
self.assertTrue(all(
r.value == e
for r, e in itertools.zip_longest(single_rolls, expected_rolls)
))
self.assertTrue(all(
r.is_discarded is discard_strategy._discarded_default
for r in single_rolls
))
randint.reset_mock()
for item_limit in (-1, -roll_count, -roll_count-1):
with self.subTest(item_limit=item_limit):
self.assertRaises(ValueError, d.get_result, item_limit)
self.assertRaises(ValueError, d.get_results, item_limit)
randint.assert_not_called()
randint.reset_mock()
class ResultsKeepStrategyTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
ordered = tuple(range(1, 11))
shuffled = list(ordered)
import random
random.seed(1024)
random.shuffle(shuffled)
random_items = tuple(random.randint(1, 10) for _ in range(10))
cls.items_pool = (
('reversed', tuple(reversed(ordered)), ordered),
('shuffled', tuple(shuffled), ordered),
('random', random_items, tuple(sorted(random_items)))
)
def test_init(self):
for count, keep, use_lowest in itertools.product(
(-1, -2), (True, False), (True, False)
):
self.assertRaises(ValueError, ResultsKeepStrategy, count=count, keep=keep, lowest=use_lowest)
def test_apply(self):
for item_set_name, items, ordered in self.items_pool:
item_count = len(items)
for discard_count, keep, use_lowest in itertools.product(
range(0, item_count+2),
(True, False),
(True, False)
):
with self.subTest(item_set=item_set_name, discard_count=discard_count, keep=keep, use_lowest=use_lowest):
s = ResultsKeepStrategy(discard_count, keep=keep, lowest=use_lowest)
discard_by_default = s.get_discarded_default(item_count)
self.assertEqual(discard_by_default, (discard_count >= item_count//2) != keep)
roll_results = [RollResult(i, discard_by_default) for i in items]
total = s.apply(roll_results)
roll_results.sort(key=lambda r: (r.value, r.is_discarded == (keep == use_lowest)))
self.assertTrue(all(
r.value == i for r, i in itertools.zip_longest(roll_results, ordered)
))
keep_lowest = (keep == use_lowest)
lowest_count = discard_count if use_lowest else max(0, item_count-discard_count)
kept_slice = ((0, lowest_count) if keep_lowest else (lowest_count, None))
self.assertEqual(total, sum(itertools.islice(ordered, *kept_slice)))
self.assertTrue(all(
r.is_discarded == ((i < lowest_count) != keep_lowest)
for i, r in enumerate(roll_results)
))
if __name__ == '__main__':
unittest.main()

84
tests/test_quickselect.py

@ -0,0 +1,84 @@
import itertools
import random
import unittest
from devpotato_bot.quickselect import partition, select
class ItemSelectionTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
ordered = tuple(range(1, 11))
shuffled = list(ordered)
random.seed(1024)
random.shuffle(shuffled)
random_items = tuple(random.randint(1, 10) for _ in range(10))
cls.items_pool = (
('reversed', tuple(reversed(ordered)), ordered),
('shuffled', tuple(shuffled), ordered),
('random', random_items, tuple(sorted(random_items)))
)
def setUp(self) -> None:
random.seed(1024)
def test_select_noop(self):
for n in [0, 1, 10]:
original_items = list(range(n))
items = list(original_items)
select(items, 0)
self.assertEqual(original_items, items)
items = list(original_items)
select(items, n)
self.assertEqual(original_items, items)
self.assertRaises(ValueError, select, list([1, 2]), -1)
self.assertRaises(ValueError, select, list([1, 2]), 3)
def test_select_sorted(self):
items = list(range(10))
select(items, 4)
self.assertSequenceEqual(range(10), items)
items = list(range(10)[::-1])
select(items, 4, compare=int.__gt__)
self.assertSequenceEqual(range(10)[::-1], items)
def test_select(self):
for item_set_name, items, ordered in self.items_pool:
for k in range(1, len(items)):
with self.subTest(item_set=item_set_name, k=k):
selected = list(items)
select(selected, k)
self.assertEqual(ordered[k-1], selected[k-1])
self.assertSequenceEqual(ordered, sorted(selected[:k]) + sorted(selected[k:]))
@staticmethod
def get_pivot_combinations(items):
n = len(items)
for left in range(n - 1):
for right in range(left, n):
for pivot_i in range(left, right + 1):
yield left, right, pivot_i
def test_partition(self):
for item_set_name, original_items, _ in self.items_pool:
for left, right, pivot_i in self.get_pivot_combinations(original_items):
with self.subTest(item_set=item_set_name, left=left, right=right, pivot_i=pivot_i):
items = list(original_items)
new_pivot_i = partition(items, left, right, pivot_i, int.__lt__)
# check ordering within [left; right]
pivot = original_items[pivot_i]
self.assertEqual(pivot, items[new_pivot_i])
self.assertTrue(all(x < pivot for x in items[left:new_pivot_i]))
self.assertTrue(all(x >= pivot for x in items[new_pivot_i+1:right+1]))
# order of items outside of [left; right] range is unchanged
self.assertSequenceEqual(items[:left], original_items[:left])
self.assertSequenceEqual(items[right+1:], original_items[right+1:])
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save