import keras
from keras.optimizers import SGD
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout
from keras.utils import np_utils
from keras.utils.np_utils import to_categorical
from keras.callbacks import Callback
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import os
import os.path
import sys

# Training configuration.
optpara = 0.05  # learning rate passed to the optimizer below

# Optimizer: plain SGD (an Adam configuration is kept for experimentation).
opt = SGD(lr=optpara)
# opt = Adam(lr=optpara, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)

# Upper bound on epochs; EarlyStopping is expected to stop far earlier.
epochs = 100000

# Accumulators for the two image sources: flattened images + integer labels.
image_list = []   # samples read from "shape3"
label_list = []
image_list2 = []  # samples read from "shape4"
label_list2 = []

# Class sub-directory names (triangles -> label 0, circles -> label 1).
a = "triangle3"
b = "circle2"
a2 = "t2"
b2 = "c2"

# Load the "shape3" dataset: one sub-directory per class.  Each image is
# converted to greyscale, resized to 28x28, flattened to a 784-vector and
# scaled to [0, 1].
for dir in os.listdir("shape3"):
    # FIX: os.listdir() yields bare entry names, so the original comparison
    # against "/.DS_Store" could never match; compare the bare name.
    if dir == ".DS_Store":
        continue

    dir1 = "shape3/" + dir

    # Map the class directories to integer labels; skip anything else.
    if dir == a:
        label = 0
    elif dir == b:
        label = 1
    else:
        continue

    for file in os.listdir(dir1):
        if file == ".DS_Store":
            continue
        label_list.append(label)
        filepath = dir1 + "/" + file
        # Greyscale ("L"), 28x28, flattened to shape (784,), normalized.
        image = np.array(Image.open(filepath).convert("L").resize((28, 28)))
        image = image.reshape(1, 784).astype("float32")[0]
        image_list.append(image / 255.)

# Load the "shape4" dataset (second image source, identical preprocessing).
for dir in os.listdir("shape4"):
    # FIX: listdir yields bare names (the original compared "/.DS_Store").
    if dir == ".DS_Store":
        continue

    # FIX: entries are listed from "shape4" but the original read files from
    # "shape/<dir>" — point the read path at the directory actually listed.
    dir1 = "shape4/" + dir

    # Use the module-level class-name constants (a2/b2) instead of repeating
    # the "t2"/"c2" literals, matching the shape3 loop's style.
    if dir == a2:
        label2 = 0
    elif dir == b2:
        label2 = 1
    else:
        continue

    for file in os.listdir(dir1):
        if file == ".DS_Store":
            continue
        label_list2.append(label2)
        filepath = dir1 + "/" + file
        # Greyscale, 28x28, flattened to shape (784,), scaled to [0, 1].
        image2 = np.array(Image.open(filepath).convert("L").resize((28, 28)))
        image2 = image2.reshape(1, 784).astype("float32")[0]
        image_list2.append(image2 / 255.)

# Convert the accumulated lists to arrays and one-hot encode the labels.
image_list = np.array(image_list)
image_list2 = np.array(image_list2)

Y = to_categorical(label_list)
Y2 = to_categorical(label_list2)

# Shuffle images and labels in unison: re-seeding numpy's global RNG with
# the same value before each shuffle produces the identical permutation.
rand=random.randint(0,1000000)
np.random.seed(rand)
np.random.shuffle(image_list)
np.random.seed(rand)
np.random.shuffle(Y)

rand2=random.randint(0,1000000)
np.random.seed(rand2)
np.random.shuffle(image_list2)
np.random.seed(rand2)
np.random.shuffle(Y2)

# Partition dataset 1 at fixed indices:
#   dust=[0:10200], 00=[10200:11500], 01=[11500:12500],
#   02=[12500:12500] (EMPTY — the split point 12500 appears twice),
#   04=[12500:].
# NOTE(review): the duplicated 12500 makes image_list02 empty before the
# concatenation below — confirm this is intentional and not a typo.
dust,image_list00,image_list01,image_list02,image_list04=np.split(image_list, [10200,11500,12500,12500])
# Dataset 2: 20=[0:0] (empty), 21=[0:1300], 22=[1300:].
image_list20,image_list21,image_list22=np.split(image_list2, [0,1300])
# Training pool: shape3 slice plus the (empty) shape4 head slice.
image_list03=np.concatenate([image_list00, image_list20])
# Labels are split at the same indices so they stay aligned with the images.
dust,Y00,Y01,Y02,Y04=np.split(Y, [10200,11500,12500,12500])
Y20,Y21,Y22=np.split(Y2, [0,1300])
Y03=np.concatenate([Y00, Y20])

# Evaluation set: (empty) shape3 slice plus the shape4 tail.
image_list02=np.concatenate([image_list02, image_list22])
Y02=np.concatenate([Y02, Y22])

# Classifier 1: a 784 -> 640 -> 640 -> 640 -> 2 MLP with ReLU hidden layers,
# 20% dropout after each, and a softmax output over the two classes.
model = Sequential()
model.add(Dense(640, input_shape=(784,)))
model.add(Activation("relu"))
model.add(Dropout(0.2))
model.add(Dense(640, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(640, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(2))
model.add(Activation("softmax"))

# FIX: the targets are one-hot 2-class vectors fed to a softmax output, so
# the matching loss is categorical_crossentropy.  With binary_crossentropy
# Keras pairs the "accuracy" metric with per-output binary accuracy, which
# reports misleadingly high values for this setup.
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])

# Stop once validation accuracy has not improved for 100 epochs, and keep
# only the best-scoring full model on disk as "model1.h5".
es_cb = EarlyStopping(
    monitor='val_acc', min_delta=0, patience=100, verbose=1, mode='auto')
modelCheckpoint = ModelCheckpoint(
    filepath='model1.h5',
    monitor='val_acc',
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    period=1)

# Train on the pooled training slice; half of it is held out for validation.
result11 = model.fit(
    image_list03, Y03,
    epochs=epochs,
    batch_size=64,
    validation_split=0.5,
    callbacks=[modelCheckpoint, es_cb])

# Per-epoch validation curves; written to text files at the end of the run.
val_acc1 = result11.history['val_acc']
val_loss1 = result11.history['val_loss']

# Classifier 2: layer-for-layer identical to model 1 above; it is only ever
# used for inference with weights restored from "model1.h5", so the layer
# sequence must match exactly for the weight load to line up.
model2 = Sequential()
model2.add(Dense(640, input_shape=(784,)))
model2.add(Activation("relu"))
model2.add(Dropout(0.2))
for _ in range(2):  # two more identical hidden blocks
    model2.add(Dense(640, activation='relu'))
    model2.add(Dropout(0.2))
model2.add(Dense(2))
model2.add(Activation("softmax"))

# Evaluate the best checkpoint on the first evaluation slice (image_list02).
if os.path.exists("./model1.h5"):
    model2.load_weights("model1.h5", by_name=False)
    c=1
    total = 0.
    ok_count = 0.
    # Split the evaluation data into 100 equal chunks.  np.split raises if
    # the sample count is not divisible by 100.
    a_split = np.split(image_list02, 100)
    b_split = np.split(Y02, 100)
    d=np.array([[1,0]])  # one-hot vector for class 0

    for i in range(100):
        # Predicted class for every sample in the chunk; only the first
        # prediction (result[0]) is scored below.
        result=model2.predict_classes(a_split[i])
        # Chunk "true" label: c=0 only when EVERY row equals [1,0];
        # c=1 only when EVERY element differs from [1,0].
        # NOTE(review): for a mixed chunk neither branch fires and c keeps
        # its value from the previous iteration — confirm chunks are
        # label-homogeneous, otherwise this silently mislabels.
        if (b_split[i]==d).all():
            c=0
        if (b_split[i]!=d).all():
            c=1
        total += 1.  # one scored item per chunk, not per image
        if c==result[0]:
            ok_count += 1.
        print(b_split[i])
        print("label:", c, "result:", result[0])


    # "seikai" is Japanese for "correct" — chunk-level accuracy in percent.
    print("seikai: ", ok_count / total * 100, "%")
    with open('acc11.txt', mode='w') as f:
        f.write(str(ok_count / total * 100))

# Second evaluation: identical procedure on the image_list04/Y04 slice.
if os.path.exists("./model1.h5"):
    model2.load_weights("model1.h5", by_name=False)
    c=1
    total = 0.
    ok_count = 0.
    # 100 equal chunks; np.split raises unless the count divides evenly.
    a2_split = np.split(image_list04, 100)
    b2_split = np.split(Y04, 100)
    d=np.array([[1,0]])  # one-hot vector for class 0

    for i in range(100):
        # Predictions for the whole chunk; only result[0] is scored.
        result=model2.predict_classes(a2_split[i])
        # NOTE(review): same caveat as the first evaluation loop — a chunk
        # with mixed labels leaves c at its previous value.
        if (b2_split[i]==d).all():
            c=0
        if (b2_split[i]!=d).all():
            c=1
        total += 1.  # one scored item per chunk
        if c==result[0]:
            ok_count += 1.
        print(b2_split[i])
        print("label:", c, "result:", result[0])

    # Chunk-level accuracy in percent ("seikai" = correct).
    print("seikai: ", ok_count / total * 100, "%")
    with open('acc12.txt', mode='w') as f:
        f.write(str(ok_count / total * 100))


# Persist the per-epoch validation curves gathered during training.
np.savetxt('val_acc1.txt', val_acc1, fmt='%.5f')
np.savetxt('val_loss1.txt', val_loss1, fmt='%.5f')