DataFrameのapply関数を用いて、shift()やdiff()を適用できる汎用的な関数を作りたい

Python
スポンサーリンク

こんにちは、@yshr10icです。

1, 2年くらい前からSIGNATEやProbSpaceなど多くのコンペに参加するようになりました。

今まで書いてきたコードもだいぶ溜まってきたので、主にテーブルコンペを対象にパイプラインを作成しています。

そんな中で時系列データに対して、shiftやdiffを取るときに冗長なコードとなってしまっていたので、DataFrameのapply関数を用いて、汎用的な関数の作り方について調べたので、ブログにまとめたいと思います。

スポンサーリンク

テストデータ準備

まずは実験対象となるテストデータを用意します。

import pandas as pd

def create_dataframe():
    return pd.DataFrame({
        'category': ['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'],
        'val': [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
    })


df = create_dataframe()
print(df)

汎用化前

次にこのcategory列をキーにして、val列のshiftやdiffを取りたいときには、以下のような関数をそれぞれ用意していました。

def shift_aggregation(input_df, group_key, group_values, nums):
    dfs = []
    for num in nums:
        _df = input_df.groupby(group_key)[group_values].shift(num)
        _df.columns = [f'shift={num}_{col}_groupby_{group_key}' for col in group_values]
    dfs.append(_df)

    return pd.concat(dfs, axis=1)


def diff_aggregation(input_df, group_key, group_values, nums):
    dfs = []
    for num in nums:
        _df = input_df.groupby(group_key)[group_values].diff(num)
        _df.columns = [f'diff={num}_{col}_groupby_{group_key}' for col in group_values]
    dfs.append(_df)

    return pd.concat(dfs, axis=1)


shift_df = shift_aggregation(df, 'category', ['val'], [1, 2])
diff_df = diff_aggregation(df, 'category', ['val'], [1, 2])
output_df = pd.concat([df, shift_df, diff_df], axis=1)

shift_aggregationメソッドとdiff_aggregationメソッドを見てみると、ほとんどの部分が共通していることが分かります。これをDataFrameのapply関数を用いて書き直します。

汎用化後

汎用化した関数を以下のように用意します。

def group_aggregation(input_df, group_key, group_values, nums, func, name):
    dfs = []
    for num in nums:
        _df = input_df.groupby(group_key)[group_values].apply(func, num=num)
        _df.columns = [f'{name}={num}_{col}_groupby_{group_key}' for col in group_values]
        dfs.append(_df)

    return pd.concat(dfs, axis=1)

group_aggregationメソッドには、引数を2つ追加しています。funcとnameです。funcはshiftやdiffをlambda関数として渡します。nameはカラム名を付けるために使用します。

次にshiftとdiffのlambda関数を用意します。ここでは引数numに何時点前のレコードを参照するかを渡せるようにします。

shift_func = lambda x, num: x.shift(num)
diff_func = lambda x, num: x.diff(num)

後はそれぞれのlambda関数をgroup_aggregationメソッドに渡すだけです。

shift_df = group_aggregation(df, 'category', ['val'], [1, 2], shift_func, 'shift')
diff_df = group_aggregation(df, 'category', ['val'], [1, 2], diff_func, 'diff')
output_df = pd.concat([df, shift_df, diff_df], axis=1)
print(output_df)

最後に

いかがだったでしょうか?基本的な内容かもしれませんが、lambda関数に引数を渡す方法や、共通的なコードを汎用的な関数にまとめる方法について知ることができました。これから引き続きコンペに向けたパイプラインの作成をしていきたいと思います。

もしもっと良い書き方があれば教えていただけると嬉しいです。

タイトルとURLをコピーしました