素数の篩への道

golang.jpにある並行処理の例、素数の篩(ふるい)ボトムアップに作ってみます。まず、

package main

import "fmt"

func main() {
        ch := make(chan int)
        go func() {
                for i := 0; ; i++ {
                        ch <- i
                }
        }()
        result := ""
        for i := 0; i < 10; i++ {
                result += fmt.Sprintf("%d, ", <-ch)
        }
        fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 2, 3, 4, 5, 6, 7, 8, 9\n"
}

で、chを通して整数値がやりとりされるのでした。もう1つチャンネル(ch1)を増やして奇数だけ通すフィルタを作ることは簡単です。

@@ -9,9 +9,18 @@
                        ch <- i
                }
        }()
+       ch2 := make(chan int)
+       go func() {
+               for {
+                       i := <-ch
+                       if i%2 != 0 {
+                               ch2 <- i
+                       }
+               }
+       }()
        result := ""
        for i := 0; i < 10; i++ {
-               result += fmt.Sprintf("%d, ", <-ch)
+               result += fmt.Sprintf("%d, ", <-ch2)
        }
-       fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 2, 3, 4, 5, 6, 7, 8, 9\n"
+       fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 3, 5, 7, 9, 11, 13, 15, 17, 19\n"
 }

もういっちょ、フィルタリングして(チャンネルch3を入れる)3の倍数でない奇数を生成してみます。

@@ -18,9 +18,19 @@
                        }
                }
        }()
+       ch3 := make(chan int)
+       go func() {
+               for {
+                       i := <-ch2
+                       if i%3 != 0 {
+                               ch3 <- i
+                       }
+               }
+       }()
+
        result := ""
        for i := 0; i < 10; i++ {
-               result += fmt.Sprintf("%d, ", <-ch2)
+               result += fmt.Sprintf("%d, ", <-ch3)
        }
-       fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 3, 5, 7, 9, 11, 13, 15, 17, 19\n"
+       fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 5, 7, 11, 13, 17, 19, 23, 25, 29\n"
 }

これで、各チャンネルでフィルタリングしてる様子が分かったと思います(ぼくは分かった気分になりました)。じゃあ、フィルタの数を増やす準備としてフィルタリング処理を関数として括り出します。

package main

import "fmt"

func filter(in chan int, prime int) chan int{
	out := make(chan int)
	go func() {
		for {
			i := <-in
			if i%prime != 0 {
				out <- i
			}
		}
	}()
	return out
}

func main() {
	ch := make(chan int)
	go func() {
		for i := 0; ; i++ {
			ch <- i
		}
	}()
	ch2 := filter(ch, 2)
	ch3 := filter(ch2, 3)
	result := ""
	for i := 0; i < 10; i++ {
		result += fmt.Sprintf("%d, ", <-ch3)
	}
	fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 5, 7, 11, 13, 17, 19, 23, 25, 29\n"
}

ところで、もともとの目的は素数を得ることでした。素数は各フィルタのチャンネルが得る最初の数です。素数を標準出力するように書き直します。

 func main() {
        ch := make(chan int)
        go func() {
-               for i := 0; ; i++ {
+               for i := 2; ; i++ {
                        ch <- i
                }
        }()
+       fmt.Println(<-ch)//=>2
        ch2 := filter(ch, 2)
+       fmt.Println(<-ch2)//=>3
        ch3 := filter(ch2, 3)
-       result := ""
-       for i := 0; i < 10; i++ {
-               result += fmt.Sprintf("%d, ", <-ch3)
-       }
-       fmt.Printf("%s\n", result[:len(result)-2]) //=> "1, 5, 7, 11, 13, 17, 19, 23, 25, 29\n"
+       fmt.Println(<-ch3)//=>5
 }

いよいよforループでくくれます。

@@ -22,9 +22,10 @@
 			ch <- i
 		}
 	}()
-	fmt.Println(<-ch)//=>2
-	ch2 := filter(ch, 2)
-	fmt.Println(<-ch2)//=>3
-	ch3 := filter(ch2, 3)
-	fmt.Println(<-ch3)//=>5
+	for i:=0; i< 3;i++ {
+		prime := <-ch
+		fmt.Println(prime)
+		ch = filter(ch, prime)
+	}
+	//=>"2\n3\4\n"
 }

しかし、"2 3 4"となり、上で生成された整数を受け取ってしまいます。これは整数の生成を関数にすれば解決します。

package main

import "fmt"

func filter(in chan int, prime int) chan int{
	out := make(chan int)
	go func() {
		for {
			i := <-in
			if i%prime != 0 {
				out <- i
			}
		}
	}()
	return out
}
func generate() chan int{
	ch := make(chan int)
	go func() {
		for i := 2; ; i++ {
			ch <- i
		}
	}()
	return ch
}
func main() {
	ch := generate()
	for i:=0; i< 3; i++ {
		prime := <-ch
		fmt.Println(prime)
		ch = filter(ch, prime)
	}
	//=> "2\n3\n5"
}

golang.jpに合わせて最後に素数を生成する関数を作って終わります。

@@ -23,12 +23,21 @@
 	}()
 	return ch
 }
-func main() {
+func sieve() chan int {
+  out := make(chan int)
+  go func() {
 	ch := generate()
-	for i:=0; i< 3; i++ {
+	for {
 		prime := <-ch
-		fmt.Println(prime)
+		out <- prime
 		ch = filter(ch, prime)
 	}
-	//=> "2\n3\n5"
+  }()
+  return out
+}
+func main() {
+	primes := sieve()
+	for i:=0; i<100; i++ {
+		fmt.Println(<-primes)
+	}
 }