handle sparse gru with fixed length

pull/9/head
David 2019-08-02 07:56:56 +09:30
parent 810229b285
commit 382b54cee6
2 changed files with 6 additions and 19 deletions

View File

@ -1,7 +1,7 @@
/*
nnet2f32.c
Writes current compiled-in model to a binary file of floats.
Writes current compiled-in model to a binary file of floats, and runs a few tests.
*/
#include <assert.h>

View File

@ -9,6 +9,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include "nnet_data.h"
#include "nnet_rw.h"
@ -177,25 +178,11 @@ void read_gru_weights(char *name, const GRULayer *l, FILE *f32) {
ret = fread(l->recurrent_weights, sizeof(float), nrecurrent, f32); assert(ret == nrecurrent);
}
int sparse_sgemv_count_idx(int rows, const int *idx)
{
int i, j;
int count = 0;
for (i=0;i<rows;i+=16) {
int cols;
cols = *idx++; count++;
for (j=0;j<cols;j++) {
idx++; count++;
}
}
return count;
}
void write_sparse_gru_weights(char *name, const SparseGRULayer *l, FILE *f32) {
int nbias = l->nb_neurons*6;
int ndiag = l->nb_neurons*3;
int nrecurrent = l->nb_neurons*l->nb_neurons*3;
int nidx = sparse_sgemv_count_idx(ndiag, l->idx);
int nidx = 32767;
printf("%s: %d %d %d %d\n", name, nbias, ndiag, nrecurrent, nidx);
fwrite(l->bias, sizeof(float), nbias, f32);
fwrite(l->diag_weights, sizeof(float), ndiag, f32);
@ -207,7 +194,7 @@ void check_sparse_gru_weights(char *name, const SparseGRULayer *l, FILE *f32) {
int nbias = l->nb_neurons*6;
int ndiag = l->nb_neurons*3;
int nrecurrent = l->nb_neurons*l->nb_neurons*3;
int nidx = sparse_sgemv_count_idx(ndiag, l->idx);
int nidx = 32767;
printf("%s: %d %d %d %d", name, nbias, ndiag, nrecurrent, nidx);
check(l->bias, nbias, f32);
check(l->diag_weights, ndiag, f32);
@ -220,9 +207,9 @@ void read_sparse_gru_weights(char *name, const SparseGRULayer *l, FILE *f32) {
int nbias = l->nb_neurons*6;
int ndiag = l->nb_neurons*3;
int nrecurrent = l->nb_neurons*l->nb_neurons*3;
int nidx = sparse_sgemv_count_idx(ndiag, l->idx);
printf("%s: %d %d %d %d\n", name, nbias, ndiag, nrecurrent, nidx);
int ret;
int nidx = 32767;
printf("%s: %d %d %d %d\n", name, nbias, ndiag, nrecurrent, nidx);
ret = fread(l->bias, sizeof(float), nbias, f32); assert(ret == nbias);
ret = fread(l->diag_weights, sizeof(float), ndiag, f32); assert(ret == ndiag);
ret = fread(l->recurrent_weights, sizeof(float), nrecurrent, f32); assert(ret == nrecurrent);