Error when running imported/restored model that uses feedable iterator #45

Open
fuhailin opened this issue Apr 12, 2022 · 0 comments

fuhailin commented Apr 12, 2022

I trained a model and saved its checkpoint files. Now I need to restore the graph from the meta file and feed it a new data iterator to continue training. I found an issue discussing this (https://github.com/tensorflow/tensorflow/issues/11679) and wrote some code to demonstrate my situation.

Current behavior

When I use ParquetDataset to feed the model, I can't restore the meta file: `tf.train.import_meta_graph` fails with the following error:

Traceback (most recent call last):
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test/io/restore_hb.py", line 223, in <module>
    resume_training(another_train_dataset, another_test_dataset)
  File "test/io/restore_hb.py", line 132, in resume_training
    saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
    **kwargs)[0]
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_flat_map__create_dataset_10' because a different function with the same name already exists.

I don't think this error is a HybridBackend bug, because I also tried TFRecordDataset and got a similar error:

Traceback (most recent call last):
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test/io/restore_pb.py", line 225, in <module>
    restore_feed()
  File "test/io/restore_pb.py", line 220, in restore_feed
    resume_training(another_train_dataset, another_test_dataset)
  File "test/io/restore_pb.py", line 155, in resume_training
    saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1697, in import_meta_graph
    **kwargs)[0]
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1721, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/pai/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: Cannot add function '__inference_Dataset_map__parse_function_55' because a different function with the same name already exists.

However, the same process works for from_tensor_slices and CsvDataset. I'm curious about the difference, and I'd like to know how to restore the model and feed it a new dataset iterator.
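Both tracebacks name a traced dataset function: `Dataset_flat_map__create_dataset` for the ParquetDataset pipeline and `Dataset_map__parse_function` for the TFRecordDataset pipeline. A plausible, unconfirmed explanation for the difference is that the working pipelines register no such functions in the graph's function library, so there is nothing to collide on import. The following hypothetical diagnostic (not part of the original report) lists the library functions a graph holds after building a toy `map()` pipeline versus a plain `from_tensor_slices` pipeline; `function_names` and `_double` are illustrative names, and ParquetDataset itself is not exercised here.

```python
import tensorflow as tf


def function_names(build_dataset):
    """Build a dataset in a fresh graph and return the names of the
    functions recorded in that graph's function library."""
    graph = tf.Graph()
    with graph.as_default():
        build_dataset()
    return [f.signature.name for f in graph.as_graph_def().library.function]


def _double(x):
    return x * 2


# map() traces `_double` into a library function, much like
# `_parse_function` in the TFRecordDataset pipeline of this issue.
mapped = function_names(lambda: tf.data.Dataset.range(4).map(_double))

# from_tensor_slices without any transformation needs no library function.
sliced = function_names(
    lambda: tf.data.Dataset.from_tensor_slices(tf.constant([1.0, 2.0])))

print("map():", mapped)
print("from_tensor_slices():", sliced)
```

If the hypothesis holds, only the `map()` graph should report a `Dataset_map...` entry; CsvDataset is presumably in the same situation as `from_tensor_slices`, since it takes no Python function.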

Expected behavior

When using ParquetDataset for training, I can restore the checkpoint and feed the model a new ParquetDataset iterator.
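In the reproduction scripts of this issue, `train_feed()` and `restore_feed()` run in the same process without resetting the default graph, so `import_meta_graph` imports into a graph that already holds the first run's dataset functions plus the newly built datasets. Assuming that is the source of the name collision, one possible workaround is to give each phase its own `tf.Graph` and to build the replacement dataset inside the fresh graph. The sketch below is hypothetical: `make_dataset`, `run_until_exhausted`, `train`, `resume` and the `"step"` collection are illustrative names, and a `map()` pipeline stands in for ParquetDataset, which is not tested here.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def make_dataset(scale):
    # Stand-in for the issue's input pipelines: map() registers a traced
    # function in the graph's function library.
    rows = np.arange(6, dtype=np.float32).reshape(3, 2)
    return tf.data.Dataset.from_tensor_slices(rows).map(lambda x: x * scale)


def run_until_exhausted(sess, op, handle_ph, handle):
    # Drain the iterator selected by `handle`; return the last value of `op`.
    last = None
    while True:
        try:
            last = sess.run(op, feed_dict={handle_ph: handle})
        except tf.errors.OutOfRangeError:
            return last


def train(prefix):
    # Phase 1: build, train and save inside its own graph.
    with tf.Graph().as_default():
        ds = make_dataset(1.0)
        handle_ph = tf1.placeholder(tf.string, shape=[])
        tf1.add_to_collection("iterator_handle", handle_ph)
        iterator = tf1.data.Iterator.from_string_handle(
            handle_ph, tf1.data.get_output_types(ds), tf1.data.get_output_shapes(ds))
        v = tf1.get_variable("v", initializer=tf.zeros((1, 2)))
        # reduce_sum turns the assign result into a plain tensor for the collection.
        step = tf.reduce_sum(tf1.assign(v, v + iterator.get_next()))
        tf1.add_to_collection("step", step)
        train_iter = tf1.data.make_initializable_iterator(ds)
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            sess.run(train_iter.initializer)
            handle = sess.run(train_iter.string_handle())
            result = run_until_exhausted(sess, step, handle_ph, handle)
            saver.save(sess, prefix)
    return result


def resume(prefix):
    # Phase 2: a brand-new graph, so the imported function library is not
    # merged into a graph that still holds phase 1's functions.
    with tf.Graph().as_default():
        saver = tf1.train.import_meta_graph(prefix + ".meta")
        handle_ph = tf1.get_collection("iterator_handle")[0]
        step = tf1.get_collection("step")[0]
        # Build the replacement dataset in the same fresh graph, after the import.
        new_iter = tf1.data.make_initializable_iterator(make_dataset(10.0))
        with tf1.Session() as sess:
            saver.restore(sess, prefix)
            sess.run(new_iter.initializer)
            handle = sess.run(new_iter.string_handle())
            return run_until_exhausted(sess, step, handle_ph, handle)


prefix = os.path.join(tempfile.mkdtemp(), "model")
print("after training:", train(prefix))
print("after resuming:", resume(prefix))
```

Calling `tf1.reset_default_graph()` between the two phases would follow the same idea; separate `tf.Graph` objects just make the boundary explicit.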

System information

  • GPU model and memory: 16G for Tesla T4
  • OS Platform: Ubuntu 18.04.5 LTS (Bionic Beaver)
  • Docker version: 20.10.14
  • GCC/CUDA/cuDNN version: gcc version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)
  • Python/conda version: Python 3.6.12
  • TensorFlow/PyTorch version: tensorflow 1.15.5+deeprec2201

Code to reproduce

Training and restoring with ParquetDataset as the input (does not work):

# TensorFlow 1.15
# https://github.com/tensorflow/tensorflow/issues/11679#
#
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from hybridbackend.tensorflow.data import DataFrame
from hybridbackend.tensorflow.data import ParquetDataset
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_parquet('train.parquet')

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_parquet('test.parquet')


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
  """Create a graph with a Dataset and an Iterator, run training, and save the model.

  An op is applied to the data coming from the iterator.
  """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element['test1']))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_hb/fufu')

def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_hb/fufu.meta')
    saver.restore(sess, 'checkpoints_hb/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_hb', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    # Continue training the model with new datasets (which may differ from the original ones)
    print("Resume training ...")

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_hb/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_hb'):
    shutil.rmtree('checkpoints_hb')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = ParquetDataset('train.parquet',
                                 batch_size=1,
                                 fields=[DataFrame.Field('test1', tf.float32),
                                         DataFrame.Field('test2', tf.float32)])
  test_dataset = ParquetDataset('test.parquet',
                                batch_size=1,
                                fields=[DataFrame.Field('test1', tf.float32),
                                        DataFrame.Field('test2', tf.float32)])
  # test_dataset = tf.data.Dataset.from_tensor_slices(
  # tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = ParquetDataset(
      'train.parquet',
      batch_size=1,
      fields=[DataFrame.Field('test1', tf.float32),
              DataFrame.Field('test2', tf.float32)])
  another_test_dataset = ParquetDataset(
      'test.parquet', batch_size=1, fields=[DataFrame.Field('test1', tf.float32),
                                            DataFrame.Field('test2', tf.float32)])

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

It doesn't work for TFRecordDataset either:

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.ops import dataset_ops


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


# Define features
feature_description = {
    'test1': tf.io.FixedLenFeature([], dtype=tf.float32),
    'test2': tf.io.FixedLenFeature([], dtype=tf.float32)
}


def _parse_function(example_proto):
  return tf.io.parse_example(example_proto, feature_description)


def write_pb(df, file):
  # Write TFrecord file
  with tf.io.TFRecordWriter(file) as writer:
    for index, row in df.iterrows():
      print(row['test1'], row['test2'])
      # Create the Example
      example = tf.train.Example(features=tf.train.Features(
          feature={
              'test1': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test1']])),
              'test2': tf.train.Feature(float_list=tf.train.FloatList(value=[row['test2']]))
          }))
      writer.write(example.SerializeToString())


new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
write_pb(train_df, 'train.tfrecord')

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
write_pb(test_df, 'test.tfrecord')


def train(train_dataset, test_dataset):
  """Create a graph with a Dataset and an Iterator, run training, and save the model.

  An op is applied to the data coming from the iterator.
  """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element['test1']))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_pb/fufu')


def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_pb/fufu.meta')
    saver.restore(sess, 'checkpoints_pb/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_pb', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    # Continue training the model with new datasets (which may differ from the original ones)
    print("Resume training ...")

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_pb/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_pb'):
    shutil.rmtree('checkpoints_pb')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
  test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = tf.data.TFRecordDataset(['train.tfrecord']).batch(1).map(_parse_function)
  another_test_dataset = tf.data.TFRecordDataset(['test.tfrecord']).batch(1).map(_parse_function)

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

But it works for CsvDataset:

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import shutil
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops

new_dtypes = {"test1": np.float32, "test2": np.float32}

train_df = pd.DataFrame(np.random.randint(0, 100, (5, 2)), columns=['test1', 'test2'])
train_df = train_df.astype(new_dtypes)
train_df.to_csv('train.csv', index=False)

test_df = pd.DataFrame(np.random.randint(0, 100, (2, 2)), columns=['test1', 'test2'])
test_df = test_df.astype(new_dtypes)
test_df.to_csv('test.csv', index=False)


def make_initializable_iterator(ds):
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()


def make_one_shot_iterator(ds):
  if hasattr(dataset_ops, 'make_one_shot_iterator'):
    return dataset_ops.make_one_shot_iterator(ds)
  return ds.make_one_shot_iterator()


def train(train_dataset, test_dataset):
  """Create a graph with a Dataset and an Iterator, run training, and save the model.

  An op is applied to the data coming from the iterator.
  """
  iterator_handle = tf.placeholder(tf.string, shape=[])
  tf.add_to_collection('iterator_handle', iterator_handle)

  iterator = tf.data.Iterator.from_string_handle(iterator_handle, dataset_ops.get_legacy_output_types(train_dataset),
                                                 dataset_ops.get_legacy_output_shapes(train_dataset),
                                                 dataset_ops.get_legacy_output_classes(train_dataset))
  train_iter = make_initializable_iterator(train_dataset)
  test_iter = make_initializable_iterator(test_dataset)
  element = iterator.get_next()

  v = tf.get_variable(name='v', initializer=tf.zeros(shape=(1, 2)))

  # to use when saving summaries
  global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
  increament_global_step = tf.assign(global_step, global_step + 1)
  global_step = global_step + 1
  tf.add_to_collection('increament_global_step', increament_global_step)

  some_op = tf.assign(v, v + tf.abs(element))
  tf.add_to_collection('some_op', tf.reduce_sum(some_op))

  tf.summary.scalar('v_sum', tf.reduce_sum(v))
  tf.summary.scalar('some_op', tf.reduce_mean(some_op))
  merged_summary = tf.summary.merge_all()
  tf.add_to_collection('merged_summary', merged_summary)

  writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_csv/fufu')


def resume_training(train_dataset, test_dataset):
  """Restore the model from file and pass some new data through it
     for further training """
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph('checkpoints_csv/fufu.meta')
    saver.restore(sess, 'checkpoints_csv/fufu')
    iterator_handle = tf.get_collection('iterator_handle')[0]
    some_op = tf.get_collection('some_op')[0]
    increament_global_step = tf.get_collection('increament_global_step')[0]
    merged_summary = tf.get_collection('merged_summary')[0]

    writer = tf.summary.FileWriter('checkpoints_csv', graph=tf.get_default_graph())

    # Make new iterators and handles
    train_iter = make_initializable_iterator(train_dataset)
    test_iter = make_initializable_iterator(test_dataset)

    # Continue training the model with new datasets (which may differ from the original ones)
    print("Resume training ...")

    train_handle = sess.run(train_iter.string_handle())
    test_handle = sess.run(test_iter.string_handle())

    # Run data iterator initialisation
    sess.run(train_iter.initializer)
    sess.run(test_iter.initializer)

    # "Training"
    print("Training")
    while True:
      try:
        [op, summary_values, g_step] = sess.run([some_op, merged_summary, increament_global_step],
                                                feed_dict={iterator_handle: train_handle})
        writer.add_summary(summary_values, global_step=g_step)
        print(op)
      except tf.errors.OutOfRangeError:
        break

    # "Test evaluation"
    print("Testing")
    while True:
      try:
        print(sess.run(some_op, feed_dict={iterator_handle: test_handle}))
      except tf.errors.OutOfRangeError:
        break

    saver.save(sess, 'checkpoints_csv/fufu')


def train_feed():
  # delete existing saved models and summary files
  if os.path.exists('checkpoints_csv'):
    shutil.rmtree('checkpoints_csv')
  # train_dataset = tf.data.Dataset.from_tensor_slices(
  #     tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
  train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
  test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)
  # test_dataset = tf.data.Dataset.from_tensor_slices(
  # tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))

  train(train_dataset, test_dataset)


def restore_feed():
  # Load and fine-tune the saved model using new data
  another_train_dataset = readers.CsvDataset("train.csv", record_defaults=[tf.float32, tf.float32], header=True)
  another_test_dataset = readers.CsvDataset("test.csv", record_defaults=[tf.float32, tf.float32], header=True)

  resume_training(another_train_dataset, another_test_dataset)


if __name__ == '__main__':
  train_feed()
  restore_feed()

Willing to contribute

Yes
