// Command main builds a small channel pipeline: gen produces integers, two sq
// stages square them concurrently, and merge fans the results back into a
// single channel from which main prints the first value.
package main

import (
	"fmt"
	"sync"
)

// gen sends the values in nums on the returned channel, then closes it.
//
// The channel is buffered to len(nums), so every send completes immediately
// and no goroutine is needed to feed it.
func gen(nums ...int) <-chan int {
	out := make(chan int, len(nums))
	for _, n := range nums {
		out <- n
	}
	close(out)
	return out
}

// sq receives values from in, squares them, and sends them on the returned
// channel until in is closed. Then sq closes the returned channel.
//
// The returned channel is unbuffered, so the worker goroutine blocks on each
// send until a receiver takes the value.
func sq(in <-chan int) <-chan int {
	out := make(chan int)
	go func() {
		// Deferred so the channel is closed on every exit path of the worker.
		defer close(out)
		for n := range in {
			out <- n * n
		}
	}()
	return out
}

// merge receives values from each input channel and sends them on the returned
// channel. merge closes the returned channel after all the input values have
// been sent. Called with no inputs, it returns a channel that is closed as
// soon as the closing goroutine runs.
//
// The returned channel has a buffer of one. A forwarding goroutine blocks
// once that slot is full and nobody is receiving, so a caller that stops
// reading early can leave at most one value unread without stranding a
// sender.
func merge(cs ...<-chan int) <-chan int {
	var wg sync.WaitGroup
	out := make(chan int, 1) // room for a single unread value

	// output copies values from c to out until c is closed, then signals
	// completion on wg.
	output := func(c <-chan int) {
		defer wg.Done()
		for n := range c {
			out <- n
		}
	}

	// Start an output goroutine for each input channel in cs. The counter is
	// raised before any goroutine starts so wg.Wait cannot return early.
	wg.Add(len(cs))
	for _, c := range cs {
		go output(c)
	}

	// Start a goroutine to close out once all the output goroutines are
	// done. This must start after the wg.Add call.
	go func() {
		wg.Wait()
		close(out)
	}()
	return out
}

func main() {
	in := gen(2, 3)

	// Distribute the sq work across two goroutines that both read from in.
	c1 := sq(in)
	c2 := sq(in)

	// Consume only the first value from the merged output.
	out := merge(c1, c2)
	fmt.Println(<-out) // 4 or 9

	// The second value fits in out's one-slot buffer, so no goroutine is left
	// blocked on a send when main returns.
}