diff --git a/examples/autokeras/autokeras_imdb_example.py b/examples/autokeras/autokeras_imdb_example.py new file mode 100644 index 00000000..bfcf90b7 --- /dev/null +++ b/examples/autokeras/autokeras_imdb_example.py @@ -0,0 +1,53 @@ +import tensorflow as tf +import autokeras as ak +import numpy as np +from tensorflow import keras +from trains import Task + +task = Task.init(project_name="autokeras", task_name="autokeras imdb example with scalars") + +def imdb_raw(): + max_features = 20000 + index_offset = 3 # word index offset + + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data( + num_words=max_features, + index_from=index_offset) + x_train = x_train + y_train = y_train.reshape(-1, 1) + x_test = x_test + y_test = y_test.reshape(-1, 1) + + word_to_id = tf.keras.datasets.imdb.get_word_index() + word_to_id = {k: (v + index_offset) for k, v in word_to_id.items()} + word_to_id[""] = 0 + word_to_id[""] = 1 + word_to_id[""] = 2 + + id_to_word = {value: key for key, value in word_to_id.items()} + x_train = list(map(lambda sentence: ' '.join( + id_to_word[i] for i in sentence), x_train)) + x_test = list(map(lambda sentence: ' '.join( + id_to_word[i] for i in sentence), x_test)) + x_train = np.array(x_train, dtype=np.str) + x_test = np.array(x_test, dtype=np.str) + return (x_train, y_train), (x_test, y_test) + + +# Prepare the data. +(x_train, y_train), (x_test, y_test) = imdb_raw() +print(x_train.shape) # (25000,) +print(y_train.shape) # (25000, 1) +print(x_train[0][:50]) # this film was just brilliant casting + +# Initialize the TextClassifier +clf = ak.TextClassifier(max_trials=3) + +tensorboard_callback_train = keras.callbacks.TensorBoard(log_dir='log') +tensorboard_callback_test = keras.callbacks.TensorBoard(log_dir='log') + +# Search for the best model. +clf.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback_train]) +clf.fit(x_test, y_test, epochs=2, callbacks=[tensorboard_callback_test]) +# Evaluate on the testing data. +print('Accuracy: {accuracy}'.format(accuracy=clf.evaluate(x_test, y_test))) \ No newline at end of file