エンジニアのソフトウェア的愛情

または私は如何にして心配するのを止めてプログラムを・愛する・ようになったか

diffをつくる(2)

昨日のコードでつくったグラフをもとに、SES(Shortest Edit Script)を見つけるコードを追加しました。体裁が変わったように見えますが、makeGraphまでは基本的に昨日と同じコードです。
今日のキモはfindSESのコード。もっとも少ない手順で一方の文字列からもう一方の文字列への変換を、終点から始点へ向かって調べていきます。

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <iterator>
#include <utility>
#include <cstdlib>

enum EditType
{
    DELETE,
    COMMON,
    ADD
};

typedef std::vector<std::vector<int> >               Grid;
typedef std::pair<EditType, std::string::value_type> Item;

std::ostream& operator << (std::ostream& out, const Item& item)
{
    switch(item.first)
    {
    case ADD:    out << '+'; break;
    case COMMON: out << ' '; break;
    case DELETE: out << '-'; break;
    default:     out << '?'; break;
    }
    return out << ' ' << item.second;
}

void initializeGrid(Grid& grid, int rowSize, int colSize)
{
    grid.resize(rowSize);
    for(std::size_t i = 0; i < grid.size(); ++i)
    {
        grid[i].resize(colSize);
    }
}

void makeGraph(Grid& grid, const std::string& str_a, const std::string& str_b)
{
    for(std::size_t row = 0; row < grid.size(); ++row)
    {
        grid[row][0] = row;
    }

    for(std::size_t col = 0; col < grid[0].size(); ++col)
    {
        grid[0][col] = col;
    }

    for(std::size_t row = 1; row < grid.size(); ++row)
    {
        for(std::size_t col = 1; col < grid[row].size(); ++col)
        {
            int d = std::min(grid[row - 1][col] + 1, grid[row][col - 1] + 1);
            if(str_a[row - 1] == str_b[col - 1])
            {
                d = std::min(d, grid[row - 1][col - 1]);
            }
            grid[row][col] = d;
        }
    }
}

std::vector<Item> findSES(Grid& grid, const std::string& str_a, const std::string& str_b)
{
    std::vector<Item> ses;
    int               row = str_a.size();
    int               col = str_b.size();

    while((row > 0) && (col > 0))
    {
        int a = grid[row][col - 1];
        int d = grid[row - 1][col];
        int c = grid[row - 1][col - 1];
        if(d < a)
        {
            if((str_a[row - 1] == str_b[col - 1]) && (c < d))
            {
                ses.push_back(Item(COMMON, str_a[row - 1]));
                --row;
                --col;
            }
            else
            {
                ses.push_back(Item(DELETE, str_a[row - 1]));
                --row;
            }
        }
        else
        {
            if((str_a[row - 1] == str_b[col - 1]) && (c < a))
            {
                ses.push_back(Item(COMMON, str_a[row - 1]));
                --row;
                --col;
            }
            else
            {
                ses.push_back(Item(ADD, str_b[col - 1]));
                --col;
            }
        }
    }

    while(col > 0)
    {
        ses.push_back(Item(ADD, str_b[col - 1]));
        --col;
    }

    while(row > 0)
    {
        ses.push_back(Item(DELETE, str_a[row - 1]));
        --row;
    }

    return std::vector<Item>(ses.rbegin(), ses.rend());
}

void diff(const std::string& str_a, const std::string& str_b)
{
    Grid grid;

    initializeGrid(grid, str_a.size() + 1, str_b.size() + 1);

    makeGraph(grid, str_a, str_b);

    std::vector<Item> ses = findSES(grid, str_a, str_b);

    std::copy(ses.begin(), ses.end(), std::ostream_iterator<Item>(std::cout, "\n"));
}

int main(int argc, char* argv[])
{
    if(argc == 3)
    {
        diff(argv[1], argv[2]);
    }
    else
    {
        std::cout << argv[0] << " <str_a> <str_b>\n";
    }

    return 0;
}

実行してみます。

$ ./my_diff2 string strength
   s
   t
   r
 - i
 + e
   n
   g
 + t
 + h

一応diffがとれました。ベタですが。ここから理解を深めていく予定。


次はO(ND)を予定。つづく。