import unittest
import sys
import relabeling

from sortedcontainers import SortedListWithKey
from six.moves import zip as izip, xrange

# Short alias for the `relabeling` module, whose functions the tests below call, to keep the
# test code concise.
r = relabeling

def skipfloats(x, n):
  """
  Returns the result of applying relabeling.nextfloat() to x, n times in a row.
  With n == 0, x is returned unchanged.
  """
  result = x
  for _ in xrange(n):
    result = relabeling.nextfloat(result)
  return result


class Item(object):
  """
  A (value, key) record; tests store these as the elements of the sorted lists they maintain.
  Both attributes are plain and mutable, so a test may re-key an item by assigning to `key`.
  """
  def __init__(self, value, key):
    self.value, self.key = value, key

  def __repr__(self):
    fields = (self.value, self.key)
    return "Item(v=%s,k=%s)" % fields


class ItemList(object):
  """
  Test harness around a SortedListWithKey of Item objects, ordered by each item's `key`.

  It applies the result of a prepare_inserts-style function to the list, and counts how many
  calls needed to adjust existing keys and how many keys were adjusted in total.
  """
  def __init__(self, val_key_pairs):
    """Builds the list from an iterable of (value, key) pairs."""
    self._slist = SortedListWithKey(key=lambda item: item.key)
    self._slist.update(Item(v, k) for (v, k) in val_key_pairs)
    # Number of insert_items() calls for which prepare_inserts returned any adjustments.
    self.num_update_events = 0
    # Total number of adjustments (re-keyed existing items) over all insert_items() calls.
    self.num_updated_keys = 0

  def get_values(self):
    """Returns the `value` of every item, in list (key) order."""
    return [item.value for item in self._slist]

  def get_list(self):
    """Returns the underlying sorted list itself (not a copy)."""
    return self._slist

  def find_value(self, value):
    """Returns the first item in list order whose value equals `value`, or None if none does."""
    return next((item for item in self._slist if item.value == value), None)

  def avg_updated_keys(self):
    """
    Returns the total number of adjusted keys divided by the current number of items.
    Raises ZeroDivisionError if the list is empty.
    """
    return float(self.num_updated_keys) / len(self._slist)

  def next(self, item):
    """Returns the item right after `item`; raises IndexError if `item` is the last one."""
    return self._slist[self._slist.index(item) + 1]

  def prev(self, item):
    """Returns the item right before `item`; raises IndexError if `item` is the first one."""
    index = self._slist.index(item)
    if index == 0:
      # Without this check, index - 1 == -1 would silently wrap around to the last item, unlike
      # next(), which raises at the end of the list.
      raise IndexError("no item precedes the first item")
    return self._slist[index - 1]

  def insert_items(self, val_key_pairs, prepare_inserts=r.prepare_inserts):
    """
    Inserts new items given as (value, key) pairs, where each key is the position requested for
    the new item. Calls prepare_inserts(sorted_list, keys), which returns (adjustments, new_keys):
    `adjustments` is a list of (index, new_key) pairs for existing items, and `new_keys` holds
    the key to actually use for each new item, in the same order as val_key_pairs.

    Returns the (adjustments, new_keys) pair unchanged, so that tests can inspect it.
    """
    # The pairs are traversed twice below (once for keys, once for values); materialize them so
    # that a one-shot iterator (e.g. zip() on Python 3) isn't exhausted by the first pass.
    val_key_pairs = list(val_key_pairs)
    keys = [k for (v, k) in val_key_pairs]
    adjustments, new_keys = prepare_inserts(self._slist, keys)
    if adjustments:
      self.num_update_events += 1
      self.num_updated_keys += len(adjustments)

    # Updating items is a bit tricky: we have to do it without violating order (just changing
    # key of an existing item easily might), so we remove items first. And we can only rely on
    # indices if we scan items in a backwards order.
    items = [self._slist.pop(index) for (index, _) in reversed(adjustments)]
    # Restore forward order, so that items line up with adjustments again.
    items.reverse()
    for (_, key), item in izip(adjustments, items):
      item.key = key
    self._slist.update(items)

    # Now add the new items.
    self._slist.update(Item(val, new_key) for (val, _), new_key in izip(val_key_pairs, new_keys))

    # For testing, pass along the return value from prepare_inserts.
    return adjustments, new_keys


class TestRelabeling(unittest.TestCase):
  """
  Tests for the `relabeling` module (aliased as `r`): the float-stepping helpers nextfloat() and
  prevfloat(), the range helpers, and prepare_inserts() / prepare_inserts_dumb().
  """

  def test_nextfloat(self):
    """nextfloat(x) is strictly greater than x, and prevfloat() undoes it."""
    def verify_nextfloat(x):
      nx = r.nextfloat(x)
      self.assertNotEqual(nx, x)
      self.assertGreater(nx, x)
      self.assertEqual(r.prevfloat(nx), x)
      # The midpoint must evaluate to one of the two endpoints; presumably this is the test's way
      # of checking that no other float lies between x and nx.
      average = (nx + x) / 2
      self.assertTrue(average == nx or average == x)

    verify_nextfloat(1)
    verify_nextfloat(-1)
    verify_nextfloat(417)
    verify_nextfloat(-417)
    verify_nextfloat(12312422)
    verify_nextfloat(-12312422)
    verify_nextfloat(0.1234)
    verify_nextfloat(-0.1234)
    verify_nextfloat(0.00005)
    verify_nextfloat(-0.00005)
    # Values at and around zero, and at the smallest positive normalized float.
    verify_nextfloat(0.0)
    verify_nextfloat(r.nextfloat(0.0))
    verify_nextfloat(sys.float_info.min)
    verify_nextfloat(-sys.float_info.min)

  def test_prevfloat(self):
    """prevfloat(x) is strictly less than x, and nextfloat() undoes it."""
    def verify_prevfloat(x):
      nx = r.prevfloat(x)
      self.assertNotEqual(nx, x)
      self.assertLess(nx, x)
      self.assertEqual(r.nextfloat(nx), x)
      # Same midpoint check as in test_nextfloat, with nx now being the lower endpoint.
      average = (nx + x) / 2
      self.assertTrue(average == nx or average == x)

    verify_prevfloat(1)
    verify_prevfloat(-1)
    verify_prevfloat(417)
    verify_prevfloat(-417)
    verify_prevfloat(12312422)
    verify_prevfloat(-12312422)
    verify_prevfloat(0.1234)
    verify_prevfloat(-0.1234)
    verify_prevfloat(0.00005)
    verify_prevfloat(-0.00005)
    # Note that, unlike test_nextfloat, 0.0 itself is not among the tested values.
    verify_prevfloat(r.nextfloat(0.0))
    verify_prevfloat(sys.float_info.min)
    verify_prevfloat(-sys.float_info.min)

  def test_range_around_float(self):
    """range_around_float(x, bits) returns the expected (begin, end) pair for points in a range."""

    # Checks that (begin, end) is returned for begin itself, for the midpoint of the range, and
    # for the points one float-step (measured at begin) above begin and below end, when those
    # points satisfy the conditions below.
    def verify_range(bits, begin, end):
      self.assertEqual(r.range_around_float(begin, bits), (begin, end))
      self.assertEqual(r.range_around_float((end + begin) / 2, bits), (begin, end))
      delta = r.nextfloat(begin) - begin
      if begin + delta < end:
        self.assertEqual(r.range_around_float(begin + delta, bits), (begin, end))
      if end - delta >= begin:
        self.assertEqual(r.range_around_float(end - delta, bits), (begin, end))

    # For small values of bits, the expected range starts at `begin` and ends 2**bits float-steps
    # later (1, 2, 16 and 1024 steps for bits = 0, 1, 4 and 10).
    def verify_small_range_at(begin):
      verify_range(0, begin, skipfloats(begin, 1))
      verify_range(1, begin, skipfloats(begin, 2))
      verify_range(4, begin, skipfloats(begin, 16))
      verify_range(10, begin, skipfloats(begin, 1024))

    verify_small_range_at(1.0)
    verify_small_range_at(0.5)
    verify_small_range_at(0.25)
    verify_small_range_at(0.75)
    verify_small_range_at(17.0)

    # With 52 bits, the expected range runs from a power of two to the next power of two.
    verify_range(52, 1.0, 2.0)
    self.assertEqual(r.range_around_float(1.4, 52), (1.0, 2.0))

    verify_range(52, 0.5, 1.0)
    self.assertEqual(r.range_around_float(0.75, 52), (0.5, 1.0))

    # Each extra bit doubles the width of the expected range around 17.
    self.assertEqual(r.range_around_float(17, 48), (17.0, 18.0))
    self.assertEqual(r.range_around_float(17, 49), (16.0, 18.0))
    self.assertEqual(r.range_around_float(17, 50), (16.0, 20.0))
    self.assertEqual(r.range_around_float(17, 51), (16.0, 24.0))
    self.assertEqual(r.range_around_float(17, 52), (16.0, 32.0))

    verify_range(51, 0.25, 0.375)
    self.assertEqual(r.range_around_float(0.27, 51), (0.25, 0.375))
    self.assertEqual(r.range_around_float(0.30, 51), (0.25, 0.375))
    self.assertEqual(r.range_around_float(0.37, 51), (0.25, 0.375))

    verify_range(51, 0.50, 0.75)
    verify_range(51, 0.75, 1.0)
    verify_range(52, 0.25, 0.5)

    # Range around 0 isn't quite right, and possibly can't be. But we test that it's at least
    # something meaningful.
    self.assertEqual(r.range_around_float(0.00, 52), (0.00, 0.5))
    self.assertEqual(r.range_around_float(0.25, 52), (0.25, 0.5))

    self.assertEqual(r.range_around_float(0.00, 50), (0.00, 0.125))
    self.assertEqual(r.range_around_float(0.10, 50), (0.09375, 0.109375))

    self.assertEqual(r.range_around_float(0.0, 53), (0.00, 1))
    self.assertEqual(r.range_around_float(0.5, 53), (0.00, 1))

    # For small values of bits, the expected range at 0 has the same width as the corresponding
    # range that starts at 0.5.
    self.assertEqual(r.range_around_float(0, 0), (0.0, skipfloats(0.5, 1) - 0.5))
    self.assertEqual(r.range_around_float(0, 1), (0.0, skipfloats(0.5, 2) - 0.5))
    self.assertEqual(r.range_around_float(0, 4), (0.0, skipfloats(0.5, 16) - 0.5))
    self.assertEqual(r.range_around_float(0, 10), (0.0, skipfloats(0.5, 1024) - 0.5))

  def test_all_distinct(self):
    """all_distinct() is True exactly when the given sequence contains no repeated values."""

    # Just like r.get_range, but includes endpoints.
    def full_range(start, end, count):
      return [start] + r.get_range(start, end, count) + [end]

    self.assertTrue(r.all_distinct(range(1000)))
    self.assertTrue(r.all_distinct([]))
    self.assertTrue(r.all_distinct([1.0]))
    self.assertFalse(r.all_distinct([1.0, 1.0]))

    # The remaining cases ask get_range() for `count` values between two endpoints that are a
    # given number of float-steps apart. The values are expected to be all distinct when count is
    # smaller than the number of steps, and to contain duplicates when it is not.
    self.assertTrue(r.all_distinct(full_range(0, 1, 1000)))
    self.assertFalse(r.all_distinct(full_range(1.0, r.nextfloat(1.0), 1)))
    self.assertFalse(r.all_distinct(full_range(1.0, skipfloats(1.0, 10), 10)))
    self.assertTrue(r.all_distinct(full_range(1.0, skipfloats(1.0, 11), 10)))
    self.assertTrue(r.all_distinct(full_range(0.1, skipfloats(0.1, 100), 99)))
    self.assertFalse(r.all_distinct(full_range(0.1, skipfloats(0.1, 100), 100)))

  def test_get_range(self):
    """get_range(start, end, count) yields `count` evenly spaced values, excluding endpoints."""
    self.assertEqual(r.get_range(0.0, 2.0, 3), [0.5, 1, 1.5])
    self.assertEqual(r.get_range(1, 17, 7), [3,5,7,9,11,13,15])
    self.assertEqual(r.get_range(-1, 1.5, 4), [-0.5, 0, 0.5, 1])

  def test_prepare_inserts_simple(self):
    """
    prepare_inserts(slist, keys) returns (adjustments, new_keys); these cases need no adjustments.
    """
    # With an empty list, new keys are expected to be whole numbers starting at 1.0, assigned
    # according to the relative order of the requested keys (ties keep their order in the input).
    slist = SortedListWithKey(key=lambda i: i.key)
    self.assertEqual(r.prepare_inserts(slist, [4.0]), ([], [1.0]))
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [1.0]))
    self.assertEqual(r.prepare_inserts(slist, [4.0, 4.0, 5, 6]), ([], [1.0, 2.0, 3.0, 4.0]))
    self.assertEqual(r.prepare_inserts(slist, [4, 5, 6, 5, 4]), ([], [1,3,5,4,2]))
    slist.update(Item(v, k) for (v, k) in zip(['a','b','c'], [3.0, 4.0, 5.0]))
    # Inserting before the first existing key (3.0) is expected to produce the key 1.5.
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [1.5]))

    # Insert several items at once, including multiple ones at the same requested key, and check
    # the resulting order once the items are added under their new keys.
    values = 'defgijkl'
    to_update, to_add = r.prepare_inserts(slist, [3,3,4,5,6,4,6,4])
    self.assertEqual(to_add, [1., 2., 3.25, 4.5, 6., 3.5, 7., 3.75])
    self.assertEqual(to_update, [])
    slist.update(Item(v, k) for (v, k) in zip(values, to_add))
    self.assertEqual([i.value for i in slist], list('deafjlbgcik'))

  def test_with_invalid(self):
    """Existing keys that are zero, duplicated with zero, negative or infinite get renumbered."""
    # Adjustments are (index, new_key) pairs for items already in the list.
    slist = SortedListWithKey(key=lambda i: i.key)
    slist.add(Item('a', 0))
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 2.0)], [1.0]))
    self.assertEqual(r.prepare_inserts(slist, [1.0]), ([], [1.0]))

    slist = SortedListWithKey(key=lambda i: i.key)
    slist.update(Item(v, k) for (v, k) in zip('abcdef', [0, 0, 0, 1, 1, 1]))
    # We expect the whole range to be renumbered.
    self.assertEqual(r.prepare_inserts(slist, [0.0, 0.0]),
                     ([(0, 3.0), (1, 4.0), (2, 5.0), (3, 6.0), (4, 7.0), (5, 8.0)],
                      [1.0, 2.0]))

    # We also expect a renumbering if there are negative or infinite values.
    slist = SortedListWithKey(key=lambda i: i.key)
    slist.add(Item('a', float('inf')))
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 2.0)], [1.0]))
    self.assertEqual(r.prepare_inserts(slist, [float('inf')]), ([(0, 2.0)], [1.0]))

    slist = SortedListWithKey(key=lambda i: i.key)
    slist.add(Item('a', -17.0))
    # Here the requested key 0.0 sorts after the existing -17.0, so the existing item is
    # expected to come first after renumbering.
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([(0, 1.0)], [2.0]))
    self.assertEqual(r.prepare_inserts(slist, [float('-inf')]), ([(0, 2.0)], [1.0]))

  def test_with_dups(self):
    """Duplicate keys that are away from the insertion point don't force any adjustments."""
    slist = SortedListWithKey(key=lambda i: i.key)
    slist.update(Item(v, k) for (v, k) in zip('abcdef', [1, 1, 1, 2, 2, 2]))
    self.assertEqual(r.prepare_inserts(slist, [0.0]), ([], [0.5]))

  def test_renumber_endpoints1(self):
    """Repeated insertions at both ends, starting from an empty list."""
    self._do_test_renumber_ends([])

  def test_renumber_endpoints2(self):
    """Repeated insertions at both ends, starting from a list of four items."""
    self._do_test_renumber_ends(list(zip("abcd", [40,50,60,70])))

  def _do_test_renumber_ends(self, initial):
    """
    Shared body of the endpoint tests. `initial` is a list of (value, key) pairs to start with.
    """
    # Test insertions that happen together on the left and on the right.
    slist = ItemList(initial)
    for i in xrange(2000):
      slist.insert_items([(i, float('-inf')), (-i, float('inf'))])

    # Each step puts i at the very beginning and -i at the very end, so the left part is in
    # reverse insertion order, and the right part is in insertion order.
    self.assertEqual(slist.get_values(),
                     rev_range(2000) + [v for v,k in initial] + list(xrange(0, -2000, -1)))
    #print slist.num_update_events, slist.num_updated_keys
    self.assertLess(slist.avg_updated_keys(), 3)
    self.assertLess(slist.num_update_events, 80)

  def test_renumber_left(self):
    """Repeatedly inserting just before the same item ('c') keeps the adjustment counts low."""
    slist = ItemList(zip("abcd", [4,5,6,7]))
    ins_item = slist.find_value('c')
    for i in xrange(1000):
      # ins_item.key is re-read on every iteration, since adjustments may have changed it.
      slist.insert_items([(i, ins_item.key)])

    # Check the end result
    self.assertEqual(slist.get_values(), ['a', 'b'] + list(xrange(1000)) + ['c', 'd'])
    self.assertAlmostEqual(slist.avg_updated_keys(), 3.5, delta=1)
    self.assertLess(slist.num_update_events, 40)

  def test_renumber_right(self):
    """Repeatedly inserting just after the same item ('b') keeps the adjustment counts low."""
    slist = ItemList(zip("abcd", [4,5,6,7]))
    ins_item = slist.find_value('b')
    for i in xrange(1000):
      slist.insert_items([(i, r.nextfloat(ins_item.key))])

    # Check the end result
    # (Each new item lands right after 'b', so the new items end up in reverse insertion order.)
    self.assertEqual(slist.get_values(), ['a', 'b'] + rev_range(1000) + ['c', 'd'])
    self.assertAlmostEqual(slist.avg_updated_keys(), 3.5, delta=1)
    self.assertLess(slist.num_update_events, 40)

  def test_renumber_left_dumb(self):
    """Same scenario as test_renumber_left, but using prepare_inserts_dumb()."""
    # Here we use the "dumb" approach, and see that in our test case it's significantly worse.
    # (The badness increases with the number of insertions, but we'll keep numbers small to keep
    # the test fast.)
    slist = ItemList(zip("abcd", [4,5,6,7]))
    ins_item = slist.find_value('c')
    for i in xrange(1000):
      slist.insert_items([(i, ins_item.key)], prepare_inserts=r.prepare_inserts_dumb)
    self.assertEqual(slist.get_values(), ['a', 'b'] + list(xrange(1000)) + ['c', 'd'])
    self.assertGreater(slist.avg_updated_keys(), 8)

  def test_renumber_right_dumb(self):
    """Same scenario as test_renumber_right, but using prepare_inserts_dumb()."""
    slist = ItemList(zip("abcd", [4,5,6,7]))
    ins_item = slist.find_value('b')
    for i in xrange(1000):
      slist.insert_items([(i, r.nextfloat(ins_item.key))], prepare_inserts=r.prepare_inserts_dumb)
    self.assertEqual(slist.get_values(), ['a', 'b'] + rev_range(1000) + ['c', 'd'])
    self.assertGreater(slist.avg_updated_keys(), 8)

  def test_renumber_multiple(self):
    """Many tight insertions made within a single insert_items() call, repeated N times."""
    # In this test, we make multiple difficult insertions at each step: to the left and to the
    # right of each value. This should involve some adjustments that get affected by subsequent
    # adjustments during the same prepare_inserts() call.
    slist = ItemList(zip("abcd", [4,5,6,7]))
    # We insert items on either side of each of the original items (a, b, c, d).
    ins_items = list(slist.get_list())
    N = 250
    for i in xrange(N):
      # Inserted values are named <original value><side><step>: 'r'/'l' for items placed right
      # next to the original item, 'R'/'L' for items placed next to its current neighbors.
      slist.insert_items([("%sr%s" % (x.value, i), r.nextfloat(x.key)) for x in ins_items] +
                         [("%sl%s" % (x.value, i), x.key) for x in ins_items] +
                         # After the first insertion, also insert items next on either side of the
                         # neighbors of the original a, b, c, d items.
                         ([("%sR%s" % (x.value, i), r.nextfloat(slist.next(x).key))
                           for x in ins_items] +
                          [("%sL%s" % (x.value, i), slist.prev(x).key) for x in ins_items]
                          if i > 0 else []))

    # The list should grow like this:
    #  a, b, c, d
    #  al0, a, ar0, ... (same for b, c, d)
    #  aL1, al0, al1, a, ar1, ar0, aR1, ...
    #  aL1, al0, aL2, al1, al2, a, ar2, ar1, aR2, ar0, aR1, ...
    def left_half(val):
      # Builds [L1, l0, L2, l1, ..., L(N-1), l(N-2), l(N-1)], each prefixed with val. The list
      # has 2*N - 1 slots; the last even slot is first filled with L<N> (no such item is ever
      # inserted), and is then overwritten with l(N-1).
      half = list(xrange(2*N - 1))
      half[0::2] = ['%sL%d' % (val, i) for i in xrange(1, N + 1)]
      half[1::2] = ['%sl%d' % (val, i) for i in xrange(0, N - 1)]
      half[-1] = '%sl%d' % (val, N - 1)
      return half

    def right_half(val):
      # Best described as the reverse of left_half
      return [v.replace('l', 'r').replace('L', 'R') for v in reversed(left_half(val))]

    # The list we expect to see is of the form
    #   [aL1, al0, aL2, al1, ..., aL(N-1), al(N-2), al(N-1), a,
    #    ar(N-1), ar(N-2), aR(N-1), ..., ar1, aR2, ar0, aR1],
    # followed by the same sequence for b, c, and d.
    self.assertEqual(slist.get_values(), sum([left_half(v) + [v] + right_half(v)
                                              for v in ('a', 'b', 'c', 'd')], []))

    self.assertAlmostEqual(slist.avg_updated_keys(), 2.5, delta=1)
    self.assertLess(slist.num_update_events, 40)


def rev_range(n):
  """Returns the list [n-1, n-2, ..., 1, 0]; the list is empty when n <= 0."""
  # Count down directly, rather than building the ascending list and reversing it.
  return list(xrange(n - 1, -1, -1))

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()