甘いもの詰め合わせ(ダイナミック・プログラミング)

プログラマのための論理パズルの解答例です。
"./a.out 1 160"で問題1の解答が、"./a.out 5 160"で問題2の解答が得られます。第1引数は個数1〜50に掛ける重み(個数51〜160の重みは1)、第2引数は検討する箱の最大サイズです。低性能と言ってよいPC(netbook)で計算していますが、1問あたり5分以内で計算が終わりました。

ソースコード

#include <stdio.h>
#include <assert.h>
#include <string.h>
/*
 * Prepares the cost table for one box lineup (1, s1, s2, s3).
 *
 * costs[i] holds the number of boxes used to pack exactly i + 1 items.
 * Every entry starts at `max`, an upper bound (i + 1 single boxes never
 * exceed max); the entries for 1 item and for each box size are set to 1.
 *
 * costs : table of at least `max` ints
 * s1..s3: box sizes; each must satisfy 1 < s <= max
 * max   : number of entries to initialise
 */
void cost_init(int* costs, int s1, int s2, int s3, int max) {
  int i;
  /* Each size is used as an index (s - 1), so it must lie in 2..max.
   * Both bounds are spelled out: `1 < s < max` would parse as
   * `(1 < s) < max` and always be true. */
  assert(1 < s1 && s1 <= max);
  assert(1 < s2 && s2 <= max);
  assert(1 < s3 && s3 <= max);
  for(i = 0; i < max; i++)
    costs[i] = max;
  costs[0] = costs[s1 - 1] = costs[s2 - 1] = costs[s3 - 1] = 1;
}
/*
 * Finalises costs[max - 1], the box count for exactly `max` items.
 *
 * Every split of the items into two piles of i and max - i items is
 * tried, and the cheapest split replaces the current value if it is
 * smaller. The current value is 1 for a box size, or the upper bound
 * written by cost_init.
 *
 * Call for max = 2, 3, ... in increasing order so that every entry
 * below costs[max - 1] is already final when it is read.
 */
void cost_set(int* costs, int max) {
  int i;
  /* Splitting into (i, max - i) costs the same as (max - i, i), so only
   * the first half of the splits needs to be examined. */
  for(i = 1; i <= max / 2; i++) {
    int split = costs[i - 1] + costs[max - i - 1];
    if(split < costs[max - 1])
      costs[max - 1] = split;
  }
}
/*
 * Weighted mean of n[0..size-1] with weights w[0..size-1]:
 * sum(n[i] * w[i]) / sum(w[i]). Both sums are accumulated in int and
 * the division is done in double.
 */
double average(int* n,int* w, int size) {
  int weighted_sum = 0;
  int weight_sum = 0;
  int remaining;

  for(remaining = size; remaining > 0; remaining--) {
    weighted_sum += *n * *w;
    weight_sum += *w;
    n++;
    w++;
  }
  return (double)weighted_sum / weight_sum;
}
/*
 * Converts the leading run of decimal digits in `str` to an int.
 *
 * Parsing stops at the first character outside '0'..'9', so "12x"
 * yields 12 and "" or "abc" yields 0. Non-digits are no longer folded
 * into the result as garbage. There is no sign handling and no overflow
 * check, so the digits must fit in an int.
 */
int str_to_int(const char* str) {
  int n = 0;
  while(*str >= '0' && *str <= '9') {
    n = n * 10 + (*str - '0');
    str++;
  }
  return n;
}
/*
 * Weighted average box count for the lineup (1, s1, s2, s3) over item
 * counts 1..160. w must hold at least 160 weights; w[i] weights the
 * count i + 1.
 */
double cost_average(int* w, int s1, int s2, int s3) {
  enum { TABLE_SIZE = 160 };
  int costs[TABLE_SIZE];
  int items;

  cost_init(costs, s1, s2, s3, TABLE_SIZE);
  /* Build the table bottom-up: each size only needs smaller ones. */
  for(items = 2; items <= TABLE_SIZE; items++)
    cost_set(costs, items);
  return average(costs, w, TABLE_SIZE);
}

void test(void);

/*
 * Entry point.
 *
 *   ./a.out test              run the self-tests
 *   ./a.out <weight> <size>   try every box lineup (1, i, j, k) with
 *                             2 <= i < j < k <= size and print the lineup
 *                             with the smallest weighted average box
 *                             count, followed by that average
 *
 * Item counts 1..50 are weighted by <weight>, counts 51..160 by 1.
 */
int main(int argc, char* argv[]) {
  enum { MAX_ITEMS = 160, HEAVY_ITEMS = 50 };
  int w[MAX_ITEMS];
  int weight, size, n, i, j, k;
  /* Starting best value; the first lineup that does better replaces it. */
  double ave = MAX_ITEMS;
  /* Start empty so the printf below is well-defined even if no lineup wins. */
  char str[32] = "";

  if(argc == 2 && !strcmp(argv[1], "test")) {
    test();
    return 0;
  }
  if(argc != 3) {
    fprintf(stderr, "usage:%s <weight> <iteration>\n", argv[0]);
    return 1;
  }

  weight = str_to_int(argv[1]);
  for(n = 0; n < HEAVY_ITEMS; n++)
    w[n] = weight;
  for(n = HEAVY_ITEMS; n < MAX_ITEMS; n++)
    w[n] = 1;

  size = str_to_int(argv[2]);
  /* Three distinct sizes from 2..size need size >= 4. cost_average()
   * indexes a fixed 160-entry table with k - 1, so size must not exceed
   * 160. This is checked explicitly because assert() vanishes under NDEBUG. */
  if(size <= 3 || size > MAX_ITEMS) {
    fprintf(stderr, "%s: size must be between 4 and %d\n", argv[0], MAX_ITEMS);
    return 1;
  }

  for(i = 2; i <= size-2; i++) {
    for(j = i+1; j <= size-1; j++) {
      for(k = j+1; k <= size; k++) {
        double tave = cost_average(w, i, j, k);
        if(ave > tave) {
          ave = tave;
          snprintf(str, sizeof str, "(1, %d, %d, %d)", i, j, k);
        }
      }
    }
  }
  printf("%s\n", str);
  printf("%.2f\n", ave);
  return 0;
}

/* cost_init must give cost 1 to the 1-item entry and to each box size
 * (5, 10, 20) and leave every other entry at the `max` bound. */
void test_init(void) {
  const int max = 160;
  int costs[160];
  int idx;

  cost_init(costs, 5, 10, 20, max);
  assert(1 == costs[0] && 1 == costs[4] &&
         1 == costs[9] && 1 == costs[19]);
  for(idx = 0; idx < 160; idx++) {
    if(idx == 0 || idx == 4 || idx == 9 || idx == 19)
      continue;
    assert(max == costs[idx]);
  }
}
/* Checks cost_set on the two smallest sizes for three lineups, then on a
 * full 2..160 fill for the lineup (1, 5, 10, 20). */
void test_set(void) {
  /* smallest box size, expected costs[1], expected costs[2] */
  static const int cases[][3] = {
    {5, 2, 3},
    {2, 1, 2},
    {3, 2, 1},
  };
  int costs[160];
  int c;
  int items;

  for(c = 0; c < (int)(sizeof cases / sizeof cases[0]); c++) {
    cost_init(costs, cases[c][0], 10, 20, 160);
    cost_set(costs, 2);
    assert(cases[c][1] == costs[1]);
    cost_set(costs, 3);
    assert(cases[c][2] == costs[2]);
  }

  cost_init(costs, 5, 10, 20, 160);
  for(items = 2; items <= 160; items++)
    cost_set(costs, items);
  assert(1 == costs[0] && 1 == costs[4] &&
         1 == costs[9] && 1 == costs[19]);
  assert(4 == costs[7]);
  assert(6 == costs[52]);
}
/* With equal weights, average() of 1..4 is the plain mean, 2.5. */
void test_average(void) {
  int values[] = {1, 2, 3, 4};
  int weights[] = {1, 1, 1, 1};

  assert(average(values, weights, 4) == 2.5);
}
/* Checks str_to_int on plain decimal strings, including "0" and the
 * empty string (both yield 0). */
void test_str_to_int(void) {
  assert(123 == str_to_int("123"));
  assert(34 == str_to_int("34"));
  assert(2011 == str_to_int("2011"));
  assert(0 == str_to_int("0"));
  assert(0 == str_to_int(""));
}
void test(void) {
  test_init();
  test_set();
  test_average();
  test_str_to_int();
}