phasenn/rateK_train.py

160 lines
5.0 KiB
Python
Executable File

#!/usr/bin/python3
# rateK_train.py
#
# David Rowe Dec 2019
#
# Experiments in interpolating rate K vectors using NNs and other
# techniques.
'''
Usage:
$ c2sim ~/Downloads/train_8k.sw --rateK --rateKout ~/phasenn/rateK.f32
$ ./rateK_train.py rateK.f32 --dec 4 --frame 30 --epochs 25
'''
import os

# less verbose tensorflow ....
# This has to be in the environment before keras (and with it the backend)
# is imported below; set afterwards, the backend's start-up messages have
# already been printed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import sys
import matplotlib.pyplot as plt
from scipy import signal
import codec2_model
import argparse
from keras.layers import Input, Dense, Concatenate
from keras import models,layers
from keras import initializers
from keras import backend as K

# constants
nb_batch = 32     # batch size handed to model.fit()
newamp1_K = 20    # number of elements in one rate K vector
nb_plots = 6      # reassigned from dec before the per-frame plots are drawn
N = 80            # not referenced in the visible part of this script
def list_str(values):
    """Split a comma-separated string into a list of its fields.

    Takes a single string argument and returns a value, which is the
    shape argparse expects of a ``type=`` callable; no option in the
    visible part of this script uses it.

    values:  string such as "a,b,c"
    returns: list of sub-strings, e.g. ['a', 'b', 'c'].  No stripping or
             filtering is done, so "" gives [''] and "a, b" gives
             ['a', ' b'].
    """
    return values.split(',')
# Command line: one positional feature file plus optional integer settings.
parser = argparse.ArgumentParser(description='Train a NN to interpolate rate K vectors')
parser.add_argument('featurefile', help='f32 file of newamp1 rate K vectors')
parser.add_argument('--dec', type=int, default=4, help='decimation rate')
# default given as an int, like the other integer options
parser.add_argument('--frame', type=int, default=30, help='Frames to view')
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
args = parser.parse_args()
# dec is used as a divisor and as a range() step further down, and dec-1
# frames are interpolated between each kept pair, so anything below 2 is
# either an error (0) or leaves nothing to interpolate (1).
if args.dec < 2:
    parser.error("--dec must be at least 2 (got %d)" % args.dec)
dec = args.dec
# read in rate K vectors: the file is read as a flat stream of float32
# values and cut into rows of newamp1_K values each
features = np.fromfile(args.featurefile, dtype='float32')
nb_features = newamp1_K
nb_samples = len(features) // nb_features
nb_extra = len(features) - nb_samples*nb_features
if nb_extra:
    # a partial vector at the end of the file would make the reshape below
    # raise, so drop it and carry on with the complete vectors
    print("warning: ignoring %d trailing values that do not form a complete vector" % (nb_extra))
    features = features[:nb_samples*nb_features]
print("nb_samples: %d" % (nb_samples))
# scaled down by 20 here; the variance calculations later in the script
# multiply by 20 again before comparing
rateK = np.reshape(features, (nb_samples, nb_features))/20
print(rateK.shape)
# set up training data.  Each training vector pairs the rate K vectors of
# frames i and i+dec (the NN input) with the dec-1 frames in between (the
# NN target).  The target width follows dec, so any decimation rate >= 2
# can be used rather than only dec=4.
if dec < 2:
    raise ValueError("dec must be at least 2 so there are frames to interpolate (got %d)" % dec)
nb_interp = dec-1
# exact number of (i, i+dec) pairs that fit in the data, so no unused
# all-zero row is left at the end of the training arrays
nb_vecs = len(range(0,nb_samples-dec,dec))
if nb_vecs == 0:
    raise ValueError("need more than %d rate K vectors to build a training set (got %d)" % (dec, nb_samples))
inputs = np.zeros((nb_vecs, 2*newamp1_K))
outputs = np.zeros((nb_vecs, nb_interp*newamp1_K))
outputs_lin = np.zeros((nb_vecs, nb_interp*newamp1_K))
outputs_linpf = np.zeros((nb_vecs, nb_interp*newamp1_K))
nv = 0
for i in range(0,nb_samples-dec,dec):
    inputs[nv,:newamp1_K] = rateK[i,:]
    inputs[nv,newamp1_K:] = rateK[i+dec,:]
    A = inputs[nv,:newamp1_K]; B = inputs[nv,newamp1_K:]
    AB = A-B
    AB_energy = np.dot(AB,AB)

    for j in range(nb_interp):
        st = j*newamp1_K
        T = rateK[i+1+j,:]

        # target: the actual in-between frame
        outputs[nv,st:st+newamp1_K] = T

        # linear interpolation for reference, weights stepping evenly
        # from A towards B
        c = (j+1)/dec
        outputs_lin[nv,st:st+newamp1_K] = (1-c)*A + c*B

        # linear interpolation with per frame selection of c: the c that
        # minimises the squared error between c*A + (1-c)*B and T.  When
        # A and B are identical every c gives the same result (B) and the
        # formula would be 0/0, so c is pinned to 0 in that case.
        if AB_energy > 0:
            c = -np.dot((B-T),AB)/AB_energy
        else:
            c = 0.0
        outputs_linpf[nv,st:st+newamp1_K] = c*A + (1-c)*B

    nv += 1
print(inputs.shape, outputs.shape)

# set nn to 0 to skip training
nn = 1
if nn:
    # our model: two tanh hidden layers and a linear output layer that
    # produces all dec-1 interpolated vectors at once
    nb_hidden = 3*newamp1_K
    model = models.Sequential()
    model.add(layers.Dense(nb_hidden, activation='tanh', input_dim=2*newamp1_K))
    model.add(layers.Dense(nb_hidden, activation='tanh'))
    model.add(layers.Dense(nb_interp*newamp1_K))
    model.summary()

    # fit the model
    from keras import optimizers
    #sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='mse', optimizer="adam")
    history = model.fit(inputs, outputs, batch_size=nb_batch, epochs=args.epochs, validation_split=0.1)

    # test the model on the training data
    outputs_nnest = model.predict(inputs)

    # training and validation loss against epoch
    plt.figure(1)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.legend(['train', 'valid'], loc='upper right')
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.show(block=False)
# results over all frames: variance of the error of each interpolator
# against the actual in-between frames, with the factor of 20 that was
# divided out when the vectors were loaded put back first
var_lin = np.var(20*outputs-20*outputs_lin)
var_linpf = np.var(20*outputs-20*outputs_linpf)
print("var_lin: %3.2f var_linpf: %3.2f" % (var_lin, var_linpf), end='')
if nn:
    # outputs_nnest only exists when the network was trained
    var_nnest = np.var(20*outputs-20*outputs_nnest)
    print(" var_nnest: %3.2f" % (var_nnest), end='')
print()
# plot results for a few frames, one figure per training vector: dec+1
# subplots showing the first input frame, the dec-1 in-between frames and
# the second input frame.  --frame is divided by dec to get the index of
# the training vector to start at.
nb_plots = dec+1; nb_plotsy = 1; nb_plotsx = nb_plots
frame = int(args.frame/dec)
nb_frames = inputs.shape[0]
plt.figure(2)
loop = True
print("Press key to advance, mouse click on last figure to finish....")
# waitforbuttonpress() returns True for a key press and False for a mouse
# click, which is what ends the loop; running off the end of the training
# vectors ends it too instead of raising IndexError.
while loop and frame < nb_frames:
    # suptitle titles the figure as a whole; plt.title() here would attach
    # the title to an extra axes created before the subplots exist
    plt.suptitle('rate K Amplitude Spectra')
    for d in range(dec+1):
        plt.subplot(1, nb_plots, d+1)
        if d == 0:
            # first input frame
            plt.plot(inputs[frame,:newamp1_K],'g')
        elif d == dec:
            # second input frame
            plt.plot(inputs[frame,newamp1_K:],'g')
        else:
            # in-between frame: actual (green), linear interpolation
            # (blue) and the NN estimate or per-frame linear fit (red)
            st = (d-1)*newamp1_K
            plt.plot(outputs[frame,st:st+newamp1_K],'g')
            plt.plot(outputs_lin[frame,st:st+newamp1_K],'b')
            if nn:
                plt.plot(outputs_nnest[frame,st:st+newamp1_K],'r')
            else:
                plt.plot(outputs_linpf[frame,st:st+newamp1_K],'r')
        # same vertical scale on every subplot
        plt.ylim((-1,4))

    # error variance for this training vector only
    var_lin = np.var(20*outputs[frame,:]-20*outputs_lin[frame,:])
    print("frame: %d var_lin: %3.2f " % (frame,var_lin), end='')
    if nn:
        var_nnest = np.var(20*outputs[frame,:]-20*outputs_nnest[frame,:])
        print("var_nnest: %3.2f" % (var_nnest), end='')
    else:
        var_linpf = np.var(20*outputs[frame,:]-20*outputs_linpf[frame,:])
        print("var_linpf: %3.2f" % (var_linpf), end='')
    print(flush=True)

    plt.show(block=False)
    loop = plt.waitforbuttonpress(0)
    frame += 1
    plt.clf()