/* backprop.h (2.3 KB) */
  1. #ifndef _BACKPROP_H_
  2. #define _BACKPROP_H_
  /* Random-number scale constant: 0x7fffffff == 2^31 - 1.
   * NOTE(review): presumably the upper bound of the RNG output used by the
   * implementation -- confirm before relying on it. */
  #define BIGRND 0x7fffffff

  /* Thread / block geometry */
  #define THREADS 256
  #define BLOCK_SIZE 16
  #define WIDTH 16   /* shared memory width */
  #define HEIGHT 16  /* shared memory height */

  /* Training hyper-parameters (unsuffixed, so these are double literals) */
  #define ETA 0.3       /* eta value */
  #define MOMENTUM 0.3  /* momentum value */

  /* Host-side parallelism */
  #define NUM_THREAD 4  /* OpenMP threads */
  /* Three-layer network state: input -> hidden -> output.
   * NOTE(review): the length of every array below (e.g. whether index 0 is
   * reserved for a bias unit) is fixed by bpnn_create(); confirm there before
   * indexing. */
  typedef struct {
    /* Layer sizes */
    int input_n;   /* number of input units */
    int hidden_n;  /* number of hidden units */
    int output_n;  /* number of output units */

    /* Unit values, one array per layer */
    float *input_units;
    float *hidden_units;
    float *output_units;

    /* Error terms and training target */
    float *hidden_delta;  /* error of each hidden unit */
    float *output_delta;  /* error of each output unit */
    float *target;        /* target vector */

    /* Weight matrices */
    float **input_weights;   /* input -> hidden */
    float **hidden_weights;  /* hidden -> output */

    /* Momentum history: the previous change applied to each matrix above */
    float **input_prev_weights;   /* previous change, input -> hidden */
    float **hidden_prev_weights;  /* previous change, hidden -> output */
  } BPNN;
  27. /*** User-level functions ***/
  28. //void bpnn_initialize();
  29. void bpnn_initialize(int seed);
  30. BPNN *bpnn_create(int n_in, int n_hidden, int n_out);
  31. void bpnn_free(BPNN *net);
  32. //BPNN *bpnn_create();
  33. //void bpnn_free();
  34. void bpnn_train(BPNN *net, float *eo, float *eh);
  35. //void bpnn_train();
  36. //void bpnn_feedforward();
  37. void bpnn_feedforward(BPNN *net);
  38. void bpnn_save(BPNN *net, char *filename);
  39. //void bpnn_save();
  40. //BPNN *bpnn_read();
  41. BPNN *bpnn_read(char *filename);
  42. void load(BPNN *net);
  43. int bpnn_train_kernel(BPNN *net, float *eo, float *eh, int platform_num, int device_num, int use_gpu);
  44. void bpnn_layerforward(float *l1, float *l2, float **conn, int n1, int n2);
  45. void bpnn_output_error(float *delta, float *target, float *output, int nj, float *err);
  46. void bpnn_hidden_error(float *delta_h, int nh, float *delta_o, int no, float **who, float *hidden, float *err);
  47. void bpnn_adjust_weights(float *delta, int ndelta, float *ly, int nly, float **w, float **oldw);
  48. int setup(int argc, char** argv);
  49. float **alloc_2d_dbl(int m, int n);
  50. float squash(float x);
  51. #endif