import keras
from keras.optimizers import SGD
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout
from keras.utils import np_utils
from keras.utils.np_utils import to_categorical
from keras.callbacks import Callback
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import os
import os.path
import sys

# Optimizer handed to model.compile(); every Adam hyper-parameter is spelled
# out explicitly.  Alternative left disabled by the author: SGD(lr=optpara)
opt = Adam(
    lr=0.0001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=None,
    decay=0.0,
    amsgrad=False,
)
# Number of epochs passed to model.fit().
epochs = 10000

class CustomModelCheckpoint(Callback):
    """Save the model at the end of an epoch only if every threshold is met.

    Args:
        filepath: Path template for the saved model.  It is expanded with
            ``str.format`` using ``epoch`` and the entries of that epoch's
            ``logs`` dict.
        thresholds: Dict mapping a key of ``logs`` to a threshold value.
        inverse: If False (default) each monitored value must be greater
            than or equal to its threshold; if True it must be less than or
            equal to its threshold.
    """

    def __init__(self, filepath, thresholds, inverse=False):
        super(CustomModelCheckpoint, self).__init__()
        self.filepath = filepath
        self.thresholds = thresholds
        self.inverse = inverse

    def on_epoch_end(self, epoch, logs=None):
        """Check the epoch's logs against the thresholds and save on success.

        The model is not saved when a monitored value is on the wrong side
        of its threshold, is NaN, or is absent from ``logs`` (the latter
        also emits a ``RuntimeWarning``).
        """
        import warnings

        logs = logs or {}
        filepath = self.filepath.format(epoch=epoch, **logs)
        for name, threshold in self.thresholds.items():
            if name not in logs:
                # The threshold cannot be verified, so refuse to save rather
                # than fail with a KeyError in the middle of training.
                warnings.warn(
                    "CustomModelCheckpoint: metric {!r} is not available in "
                    "logs; skipping save.".format(name),
                    RuntimeWarning,
                )
                return
            value = logs[name]
            if self.inverse:
                failed = value > threshold
            else:
                failed = value < threshold
            # NaN compares False against everything, so test it explicitly.
            if failed or np.isnan(value):
                return
        # NOTE(review): overwrite=False means an existing file at `filepath`
        # is not silently replaced (Keras asks for confirmation instead).
        # With a fixed, non-templated filepath only the first qualifying
        # epoch is written without a prompt -- confirm this is intended.
        self.model.save(filepath, overwrite=False)

# Flattened, normalised images and their integer class labels, index-aligned.
image_list = []
label_list = []

a="triangle"
b="circle"

# Sub-directory name under "fig" -> integer class label.
label_by_dir = {a: 0, b: 1}

for dir_name in os.listdir("fig"):
    # Anything that is not a known class folder (e.g. ".DS_Store") is ignored.
    if dir_name not in label_by_dir:
        continue
    label = label_by_dir[dir_name]
    dir_path = os.path.join("fig", dir_name)
    for file_name in os.listdir(dir_path):
        if file_name == ".DS_Store":
            continue
        # Close the underlying file as soon as the pixels have been read.
        with Image.open(os.path.join(dir_path, file_name)) as img:
            # Greyscale, 28x28 -> 784 values once flattened.
            pixels = np.array(img.convert("L").resize((28, 28)))
        # Flatten to a float32 vector and scale from [0, 255] to [0, 1].
        image_list.append(pixels.reshape(784).astype("float32") / 255.)
        label_list.append(label)

# Stack the per-image vectors into one array (one row per image).
image_list = np.array(image_list)

# One-hot encode the integer labels.
Y = to_categorical(label_list)

# Shuffle images and labels in lock-step: NumPy's global RNG is re-seeded
# with the same value before each in-place shuffle, so both arrays receive
# the same row permutation.
rand = random.randint(0, 1000000)
for rows in (image_list, Y):
    np.random.seed(rand)
    np.random.shuffle(rows)


# Row boundaries of the partition: [0, 4200) is set aside in `dust`,
# [4200, 6300) and [6300, 8400) are re-joined into the *3 arrays, and
# everything from 8400 onwards ends up in the *2 arrays.
split_points = [4200, 6300, 8400]

dust, image_list0, image_list1, image_list2 = np.split(image_list, split_points)
image_list3 = np.concatenate((image_list0, image_list1))

dust, Y0, Y1, Y2 = np.split(Y, split_points)
Y3 = np.concatenate((Y0, Y1))

# Fully connected classifier: 784 inputs -> three 640-unit ReLU layers, each
# followed by dropout at rate 0.2 -> 2-unit softmax output.
model = Sequential([
    Dense(640, input_shape=(784,)),
    Activation("relu"),
    Dropout(0.2),
    Dense(640, activation='relu'),
    Dropout(0.2),
    Dense(640, activation='relu'),
    Dropout(0.2),
    Dense(2),
    Activation("softmax"),
])

model.compile(optimizer=opt, loss="binary_crossentropy", metrics=["accuracy"])

# Write ./model1.h5 only for epochs whose validation accuracy is at least 0.9.
checkpoint = CustomModelCheckpoint("./model1.h5", {'val_acc': 0.9})

# Half of the supplied samples are held out by Keras as validation data.
result = model.fit(
    image_list3,
    Y3,
    batch_size=64,
    epochs=epochs,
    validation_split=0.5,
    callbacks=[checkpoint],
)

# Per-epoch validation accuracy recorded during training.
val_acc = result.history['val_acc']

# Second network with the same layer stack as `model`; it receives the
# checkpointed weights (loaded by layer order, not by name) for evaluation.
model2 = Sequential([
    Dense(640, input_shape=(784,)),
    Activation("relu"),
    Dropout(0.2),
    Dense(640, activation='relu'),
    Dropout(0.2),
    Dense(640, activation='relu'),
    Dropout(0.2),
    Dense(2),
    Activation("softmax"),
])

if os.path.exists("./model1.h5"):
    # Score the checkpointed weights on the held-out samples.
    model2.load_weights("model1.h5", by_name=False)

    # `total` is the number of evaluated samples and `ok_count` the number
    # classified correctly; both are kept as floats.
    total = float(len(image_list2))
    ok_count = 0.
    if total:
        # One batched forward pass instead of one predict call per sample;
        # the predicted class is the index of the largest softmax output.
        predicted = np.argmax(model2.predict(image_list2), axis=-1)
        # Labels are one-hot rows, so argmax recovers the class index
        # ([1, 0] -> 0, [0, 1] -> 1).
        expected = np.argmax(Y2, axis=-1)
        ok_count = float(np.sum(predicted == expected))