Last Updated: 2/6/2024, 5:44:57 AM

# マージソート

# 概要

Wikipedia の git アニメーションがわかりやすいです。

# コード

2段階に分けられます。1つ目は、2つに分けてソートする段階。 2つ目は、分けてソートしたリストを統合する段階。

# 1. 簡単なもの

要素の参照に当たって pop, append を使い添字表記 lst[i] を使わない書き方です。 連結リスト (opens new window) を想定しています。

#
# 対話モード >>> に
# コピペで実行できます。
#
def linked_merge_sort(lst):
    if len(lst) == 1:
        return
    
    #
    # 1. 分割
    #
    q, r = divmod(len(lst), 2)
    left = [lst.pop() for _ in range(q)]
    right = [lst.pop() for _ in range(q + r)]
    
    #
    # 2. ソート
    #
    linked_merge_sort(left)
    linked_merge_sort(right)
    
    #
    # 3. 統合
    #
    left.reverse()  # 補足
    right.reverse()
    while left and right:
        if left[-1] <= right[-1]:
            lst.append(left.pop())
        else:
            lst.append(right.pop())
    while left:
        lst.append(left.pop())
    while right:
        lst.append(right.pop())

# ソート
lst = [3333, 5123, 9981, 1243, 7412]
linked_merge_sort(lst)
print(lst)
# [1243, 3333, 5123, 7412, 9981]

# 補足

これは insert(0, value) はしないため。 死ぬほど遅いから。Python の落とし穴。

これを避けるには deque を使えばいいのですが、 それはそれでコードが複雑になるので、 両者を比較、勘案してこのままにしました。

# 2. 難しいもの

pop, append を使わず添字表記 lst[i] による方法です。 静的な配列 (opens new window) を想定しています。

上のやり方だと pop, append を自由にしていてリストが伸び縮みしています。 メモリをどれくらい消費しているか全く見当がつきません。

このコードだと memory = [None] * len(lst) 分のメモリが必要なことがわかります。 反対にヒープソートやクイックソートではそのようなリストはありません、省メモリということです。

#
# 対話モード >>> に
# コピペで実行できます。
#
def merge_sort(lst):
    # ポイント
    #   マージソートは速いけど
    #   len(lst) 個分のメモリ memory を必要とします。
    memory = [None] * len(lst)  # <--- ポイント
    begin = 0
    end = len(lst) - 1
    _merge_sort(lst, memory, begin, end)


def _merge_sort(lst, memory, begin, end):
    if begin == end:
        return
    
    #
    # 1. ソート
    #
    mid = begin + (end - begin) // 2
    _merge_sort(lst, memory, begin, mid)
    _merge_sort(lst, memory, mid + 1, end)
    
    #
    # 2. コピー
    #
    for i in range(begin, end + 1):
        memory[i] = lst[i]
    
    #
    # 3. 結合
    #
    left, index, right = begin, begin, mid + 1
    while (left <= mid) and (right <= end):
        if memory[left] <= memory[right]:
            lst[index] = memory[left]
            left = left + 1
        else:
            lst[index] = memory[right]
            right = right + 1
        index = index + 1
    
    while left <= mid:
        lst[index] = memory[left]
        left = left + 1
        index = index + 1
    
    while right <= mid:
        lst[index] = memory[right]
        right = right + 1
        index = index + 1

# ソート
lst = [3333, 5123, 9981, 1243, 7412]
merge_sort(lst)
print(lst)
# [1243, 3333, 5123, 7412, 9981]

# 途中結果を表示する。

「簡単なもの」は途中結果を表示する方法がぱっと思いつかなかったので、 「難しいもの」の途中結果を表示します。

#
# 対話モード >>> に
# コピペで実行できます。
#
def merge_sort(lst):
    # ポイント
    #   マージソートは速いけど
    #   len(lst) 個分のメモリ memory を必要とします。
    memory = [None] * len(lst)  # <--- ポイント
    begin = 0
    end = len(lst) - 1
    _merge_sort(lst, memory, begin, end)


def _merge_sort(lst, memory, begin, end):
    if begin == end:
        return
    
    #
    # 1. ソート
    #
    mid = begin + (end - begin) // 2
    _merge_sort(lst, memory, begin, mid)
    _merge_sort(lst, memory, mid + 1, end)
    
    print_progress(lst, begin, end)
    
    #
    # 2. コピー
    #
    for i in range(begin, end + 1):
        memory[i] = lst[i]
    
    #
    # 3. 結合
    #
    left, index, right = begin, begin, mid + 1
    while (left <= mid) and (right <= end):
        if memory[left] <= memory[right]:
            lst[index] = memory[left]
            left = left + 1
        else:
            lst[index] = memory[right]
            right = right + 1
        index = index + 1
    
    while left <= mid:
        lst[index] = memory[left]
        left = left + 1
        index = index + 1
    
    while right <= mid:
        lst[index] = memory[right]
        right = right + 1
        index = index + 1


def print_progress(lst, *args):
    print(' '.join(f'{e:2d}' for e in lst))
    print(' '.join(progress(lst, *args)))
    print()


def progress(lst, begin, end):
    """ソートの途中経過を表示する.
    
    36 42 54 77 12 45 30 44
                << << >> >>
    
    36 42 54 77 12 30 44 45
    << << << << >> >> >> >>
    
    12 30 36 42 44 45 54 77
    
    << ... 左のソート済みのリスト
    >> ... 右のソート済みのリスト
    
    下段に行くとマージされる。
    """
    n = len(lst)
    mid = begin + (end - begin) // 2
    progress = []
    progress = progress + ['  '] * begin
    progress = progress + ['<<'] * (mid - begin + 1)
    progress = progress + ['>>'] * (end - mid)
    progress = progress + ['  '] * (n - end - 1)
    return progress

import random
lst = [random.randint(0, 99) for i in range(8)]
merge_sort(lst)
print(lst)

実行結果は以下のようになります。 ランダムな整数 8 個を作成して途中結果を表示します。 コメントは手書きで付け加えました。

$ python3 merge.py 
 0 48 64 34 90 11 51 50
 0 48 64 34 90 11 51 50
<< >>                  

 0 48 64 34 90 11 51 50
      << >>            

 0 48 34 64 90 11 51 50
<< << >> >>            

 0 34 48 64 90 11 51 50
            << >>      

 0 34 48 64 11 90 51 50
                  << >>

 0 34 48 64 11 90 50 51
            << << >> >>

 0 34 48 64 11 50 51 90 <-- << 左と >> 右でソートされたものを
<< << << << >> >> >> >>

 0 11 34 48 50 51 64 90 <-- マージする。