Tensorflowでバッチをトレーニングする


11

現在、大きなcsvファイル(> 70GBで6,000万行以上)でモデルをトレーニングしようとしています。そのために、tf.contrib.learn.read_batch_examplesを使用しています。この関数が実際にデータを読み取る方法を理解するのに苦労しています。たとえば50.000のバッチサイズを使用している場合、ファイルの最初の50.000行を読み取りますか?ファイル全体(1エポック)をループする場合は、num_rows / batch_size = 1.200ステップ数をestimator.fitメソッドに使用する必要がありますか?

現在使用している入力関数は次のとおりです。

def input_fn(file_names, batch_size):
    # Read csv files and create examples dict
    examples_dict = read_csv_examples(file_names, batch_size)

    # Continuous features
    feature_cols = {k: tf.string_to_number(examples_dict[k],
                                           out_type=tf.float32) for k in CONTINUOUS_COLUMNS}

    # Categorical features
    feature_cols.update({
                            k: tf.SparseTensor(
                                indices=[[i, 0] for i in range(examples_dict[k].get_shape()[0])],
                                values=examples_dict[k],
                                shape=[int(examples_dict[k].get_shape()[0]), 1])
                            for k in CATEGORICAL_COLUMNS})

    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32)

    return feature_cols, label


def read_csv_examples(file_names, batch_size):
    def parse_fn(record):
        record_defaults = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)

        return tf.decode_csv(record, record_defaults)

    examples_op = tf.contrib.learn.read_batch_examples(
        file_names,
        batch_size=batch_size,
        queue_capacity=batch_size*2.5,
        reader=tf.TextLineReader,
        parse_fn=parse_fn,
        #read_batch_size= batch_size,
        #randomize_input=True,
        num_threads=8
    )

    # Important: convert examples to dict for ease of use in `input_fn`
    # Map each header to its respective column (COLUMNS order
    # matters!
    examples_dict_op = {}
    for i, header in enumerate(COLUMNS):
        examples_dict_op[header] = examples_op[:, i]

    return examples_dict_op

モデルのトレーニングに使用するコードは次のとおりです。

def train_and_eval():
"""Train and evaluate the model."""

m = build_estimator(model_dir)
m.fit(input_fn=lambda: input_fn(train_file_name, batch_size), steps=steps)

同じinput_fnで再度fit関数を呼び出すとどうなりますか?それは再びファイルの先頭から始まりますか、それとも前回停止した行を覚えていますか?


私が見つかりました。medium.com/@ilblackdragon/...が input_fn tensorflow以内にバッチ処理に便利
fistynuts

ヤウはすでにそれをチェックしましたか?stackoverflow.com/questions/37091899/...
Frankstr

回答:


1

まだ答えが出ていないので、少なくともなんとか役立つ答えを出してみたいと思います。定数定義を含めると、提供されたコードを理解するのに少し役立ちます。

一般的に言えば、バッチはレコードまたはアイテムのn倍を使用します。アイテムの定義方法は、問題によって異なります。テンソルフローでは、バッチはテンソルの最初の次元でエンコードされます。あなたのcsvファイルの場合、それは行ごと(reader=tf.TextLineReader)かもしれません。列ごとに学習することもできますが、これがコードで起こっているとは思いません。データセット全体(= 1エポック)でトレーニングしたい場合は、を使用してトレーニングできますnumBatches=numItems/batchSize

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.