facetrain.c 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  #include <errno.h>
  #include <limits.h>
  #include <math.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <unistd.h>
  #include "backprop.h"
  #include "omp.h"
  /*
   * Run configuration. setup() fills these in from the command line, and
   * backprop_face() reads them. They are intentionally non-static.
   * NOTE(review): other translation units (e.g. load()) may reference them
   * via extern; confirm before narrowing their linkage.
   *
   * Library functions are declared by their headers. Do not add K&R-style
   * redeclarations such as `void exit();`: under C23 an empty parameter
   * list means (void), which conflicts with <stdlib.h>'s exit(int).
   */
  int layer_size = 0;   /* input layer size (-l); must be a multiple of 16 */
  int platform_id = 0;  /* OpenCL platform to use (-p) */
  int device_id = 0;    /* OpenCL device to use (-d) */
  int use_gpu = 0;      /* 1 for GPU, 0 for CPU (-g) */
  13. void backprop_face()
  14. {
  15. BPNN *net;
  16. int i;
  17. float out_err, hid_err;
  18. net = bpnn_create(layer_size, 16, 1); // (16, 1 can not be changed)
  19. printf("Input layer size : %d\n", layer_size);
  20. load(net);
  21. //entering the training kernel, only one iteration
  22. printf("Starting training kernel\n");
  23. bpnn_train_kernel(net, &out_err, &hid_err, platform_id, device_id, use_gpu);
  24. bpnn_free(net);
  25. printf("\nFinish the training for one iteration\n");
  26. }
  /*
   * Print the command-line help to stderr and terminate the process.
   *
   * argv0: program name substituted into the usage line.
   * Never returns. It exits with status -1, which a POSIX parent observes
   * as 255.
   */
  void Usage(char *argv0)
  {
    /* Pass the format as a string literal so the compiler can check it
       against argv0 (-Wformat / -Wformat-security). */
    fprintf(stderr,
            "\nUsage: %s [switches] \n\n"
            " -l layer_size :input layer size (must be a multiple of 16)\n"
            " -p platform_id :OCL platform to use [default=0]\n"
            " -d device_id :OCL device to use [default=0]\n"
            " -g use_gpu :1 for GPU 0 for CPU [default=0]\n",
            argv0);
    exit(-1);
  }
  37. int setup(int argc, char **argv)
  38. {
  39. int seed;
  40. int opt;
  41. extern char *optarg;
  42. while ((opt=getopt(argc,argv,"l:p:d:g:"))!= EOF) {
  43. switch (opt) {
  44. case 'p': platform_id = atoi(optarg);
  45. break;
  46. case 'd': device_id = atoi(optarg);
  47. break;
  48. case 'g': use_gpu = atoi(optarg);
  49. break;
  50. case 'l': layer_size = atoi(optarg);
  51. break;
  52. case '?': Usage(argv[0]);
  53. break;
  54. default: Usage(argv[0]);
  55. break;
  56. }
  57. }
  58. if (layer_size%16!=0){
  59. fprintf(stderr, "The number of input points must be divided by 16\n");
  60. exit(0);
  61. }
  62. seed = 7;
  63. bpnn_initialize(seed);
  64. backprop_face();
  65. exit(0);
  66. }