mirror of https://github.com/drowe67/phasenn.git
converted phasenn_test8 to tf.keras, this script has a custom loss function
parent
479a2580fb
commit
123213d16d
|
@ -8,35 +8,39 @@
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys
|
import sys
|
||||||
from keras.layers import Dense
|
|
||||||
from keras import models,layers
|
|
||||||
from keras import initializers
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from keras import backend as K
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow import keras
|
||||||
|
from tensorflow.keras import Sequential
|
||||||
|
from tensorflow.keras.layers import Dense
|
||||||
|
|
||||||
# make tensorflow less verbose ....
|
# make tensorflow less verbose ....
|
||||||
import os
|
import os
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
# custom loss function
|
# custom loss function
|
||||||
def sparse_loss(y_true, y_pred):
|
def sparse_loss(y_true, y_pred):
|
||||||
mask = K.cast( K.not_equal(y_pred, 0), dtype='float32')
|
mask = tf.cast( tf.not_equal(y_pred, 0), dtype='float32')
|
||||||
n = K.sum(mask)
|
n = tf.reduce_sum(mask)
|
||||||
return K.sum(K.square((y_pred - y_true)*mask))/n
|
return tf.reduce_sum(tf.square((y_pred - y_true)*mask))/n
|
||||||
|
|
||||||
|
'''
|
||||||
# testing custom loss function
|
# testing custom loss function
|
||||||
x = layers.Input(shape=(None,))
|
x = layers.Input(shape=(None,))
|
||||||
y = layers.Input(shape=(None,))
|
y = layers.Input(shape=(None,))
|
||||||
loss_func = K.Function([x, y], [sparse_loss(x, y)])
|
loss_func = K.Function([x, y], [sparse_loss(x, y)])
|
||||||
assert loss_func([[[1,1,1]], [[0,2,0]]]) == np.array([1])
|
assert loss_func([[[1,1,1]], [[0,2,0]]]) == np.array([1])
|
||||||
assert loss_func([[[0,1,0]], [[0,2,0]]]) == np.array([1])
|
assert loss_func([[[0,1,0]], [[0,2,0]]]) == np.array([1])
|
||||||
|
'''
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
|
|
||||||
N = 80 # number of time domain samples in frame
|
N = 80 # number of time domain samples in frame
|
||||||
nb_samples = 400000
|
nb_samples = 400000
|
||||||
nb_batch = 32
|
nb_batch = 32
|
||||||
nb_epochs = 10
|
nb_epochs = 25
|
||||||
width = 256
|
width = 256
|
||||||
pairs = 2*width
|
pairs = 2*width
|
||||||
fo_min = 50
|
fo_min = 50
|
||||||
|
@ -80,14 +84,13 @@ for i in range(nb_samples):
|
||||||
filter_phase_rect[i,2*bin] = np.cos(filter_phase[i,bin])
|
filter_phase_rect[i,2*bin] = np.cos(filter_phase[i,bin])
|
||||||
filter_phase_rect[i,2*bin+1] = np.sin(filter_phase[i,bin])
|
filter_phase_rect[i,2*bin+1] = np.sin(filter_phase[i,bin])
|
||||||
|
|
||||||
model = models.Sequential()
|
model = Sequential()
|
||||||
model.add(layers.Dense(pairs, activation='relu', input_dim=width))
|
model.add(Dense(pairs, activation='relu', input_dim=width))
|
||||||
model.add(layers.Dense(4*pairs, activation='relu'))
|
model.add(Dense(4*pairs, activation='relu'))
|
||||||
model.add(layers.Dense(pairs))
|
model.add(Dense(pairs))
|
||||||
model.summary()
|
model.summary()
|
||||||
|
|
||||||
from keras import optimizers
|
sgd = keras.optimizers.SGD(lr=0.08, decay=1e-6, momentum=0.9, nesterov=True)
|
||||||
sgd = optimizers.SGD(lr=0.2, decay=1e-6, momentum=0.9, nesterov=True)
|
|
||||||
model.compile(loss=sparse_loss, optimizer=sgd)
|
model.compile(loss=sparse_loss, optimizer=sgd)
|
||||||
history = model.fit(filter_amp, filter_phase_rect, batch_size=nb_batch, epochs=nb_epochs)
|
history = model.fit(filter_amp, filter_phase_rect, batch_size=nb_batch, epochs=nb_epochs)
|
||||||
|
|
||||||
|
@ -135,7 +138,7 @@ if plot_en:
|
||||||
|
|
||||||
plt.figure(2)
|
plt.figure(2)
|
||||||
plt.subplot(211)
|
plt.subplot(211)
|
||||||
plt.hist(err_angle*180/np.pi, bins=20)
|
plt.hist(err_angle*180/np.pi, bins=50)
|
||||||
plt.subplot(212)
|
plt.subplot(212)
|
||||||
plt.hist(Wo*(Fs/2)/np.pi, bins=20)
|
plt.hist(Wo*(Fs/2)/np.pi, bins=20)
|
||||||
plt.title('phase angle error (deg) and fo (Hz)')
|
plt.title('phase angle error (deg) and fo (Hz)')
|
||||||
|
|
|
@ -12,8 +12,8 @@ import numpy as np
|
||||||
import sys
|
import sys
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
from tensorflow import keras
|
|
||||||
|
|
||||||
|
from tensorflow import keras
|
||||||
from tensorflow.keras import Sequential
|
from tensorflow.keras import Sequential
|
||||||
from tensorflow.keras.layers import Dense
|
from tensorflow.keras.layers import Dense
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue