提交 121061b6 作者: Juan Batiz-Benet

Merge pull request #607 from jbenet/races

races: fix race conditions
...@@ -152,7 +152,7 @@ ...@@ -152,7 +152,7 @@
}, },
{ {
"ImportPath": "github.com/jbenet/go-peerstream", "ImportPath": "github.com/jbenet/go-peerstream",
"Rev": "ccc044c2a5999f36743881ff73568660a581f2f2" "Rev": "530b09b2300da11cc19f479289be5d014c146581"
}, },
{ {
"ImportPath": "github.com/jbenet/go-random", "ImportPath": "github.com/jbenet/go-random",
......
language: go
go:
- 1.2
- 1.3
- 1.4
- release
- tip
script:
- go test -race -cpu=5 -v ./...
package main
import (
"fmt"
"net"
"os"
"time"
ps "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-peerstream"
)
func die(err error) {
fmt.Fprintf(os.Stderr, "error: %s\n")
os.Exit(1)
}
func main() {
// create a new Swarm
swarm := ps.NewSwarm()
defer swarm.Close()
// tell swarm what to do with a new incoming streams.
// EchoHandler just echos back anything they write.
swarm.SetStreamHandler(ps.EchoHandler)
l, err := net.Listen("tcp", "localhost:8001")
if err != nil {
die(err)
}
if _, err := swarm.AddListener(l); err != nil {
die(err)
}
nc, err := net.Dial("tcp", "localhost:8001")
if err != nil {
die(err)
}
c, err := swarm.AddConn(nc)
if err != nil {
die(err)
}
hello := []byte("hello")
goodbye := []byte("goodbye")
swarm.SetStreamHandler(func(s *ps.Stream) {
go func() {
log("handler: got new stream.")
// s.Wait()
// log("handler: done waiting on new stream.")
buf := make([]byte, len(hello))
s.Read(buf)
log("handler: read: %s", buf)
s.Write(goodbye)
log("handler: wrote: %s", goodbye)
s.Close()
log("handler: closed.")
}()
})
for {
s, err := swarm.NewStreamWithConn(c)
if err != nil {
die(err)
}
// s.Wait()
log("sender: got new stream")
for {
<-time.After(500 * time.Millisecond)
log("sender: writing hello...")
if _, err := s.Write(hello); err != nil {
log("sender: write error: %s", err)
break
}
buf := make([]byte, len(goodbye))
if _, err := s.Read(buf); err != nil {
log("sender: read error: %s", err)
break
}
}
if err := s.Close(); err != nil {
log("sender: close error: %s", err)
}
}
}
func log(s string, ifs ...interface{}) {
fmt.Fprintf(os.Stderr, s+"\n", ifs...)
}
...@@ -325,20 +325,20 @@ func (s *Swarm) Close() error { ...@@ -325,20 +325,20 @@ func (s *Swarm) Close() error {
var wgl sync.WaitGroup var wgl sync.WaitGroup
for _, l := range s.Listeners() { for _, l := range s.Listeners() {
wgl.Add(1) wgl.Add(1)
go func() { go func(list *Listener) {
l.Close() list.Close()
wgl.Done() wgl.Done()
}() }(l)
} }
wgl.Wait() wgl.Wait()
var wgc sync.WaitGroup var wgc sync.WaitGroup
for _, c := range s.Conns() { for _, c := range s.Conns() {
wgc.Add(1) wgc.Add(1)
go func() { go func(conn *Conn) {
c.Close() conn.Close()
wgc.Done() wgc.Done()
}() }(c)
} }
wgc.Wait() wgc.Wait()
return nil return nil
......
...@@ -33,14 +33,31 @@ func (s *stream) Close() error { ...@@ -33,14 +33,31 @@ func (s *stream) Close() error {
} }
// Conn is a connection to a remote peer. // Conn is a connection to a remote peer.
type conn ss.Connection type conn struct {
sc *ss.Connection
closed chan struct{}
}
func (c *conn) spdyConn() *ss.Connection { func (c *conn) spdyConn() *ss.Connection {
return (*ss.Connection)(c) return c.sc
} }
func (c *conn) Close() error { func (c *conn) Close() error {
return c.spdyConn().Close() err := c.spdyConn().CloseWait()
if !c.IsClosed() {
close(c.closed)
}
return err
}
func (c *conn) IsClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
} }
// OpenStream creates a new stream. // OpenStream creates a new stream.
...@@ -84,6 +101,6 @@ type transport struct{} ...@@ -84,6 +101,6 @@ type transport struct{}
var Transport = transport{} var Transport = transport{}
func (t transport) NewConn(nc net.Conn, isServer bool) (pst.Conn, error) { func (t transport) NewConn(nc net.Conn, isServer bool) (pst.Conn, error) {
c, err := ss.NewConnection(nc, isServer) sc, err := ss.NewConnection(nc, isServer)
return (*conn)(c), err return &conn{sc: sc, closed: make(chan struct{})}, err
} }
...@@ -7,5 +7,6 @@ import ( ...@@ -7,5 +7,6 @@ import (
) )
func TestSpdyStreamTransport(t *testing.T) { func TestSpdyStreamTransport(t *testing.T) {
t.Skip("spdystream is known to be broken")
psttest.SubtestAll(t, Transport) psttest.SubtestAll(t, Transport)
} }
...@@ -44,11 +44,6 @@ func checkErr(t *testing.T, err error) { ...@@ -44,11 +44,6 @@ func checkErr(t *testing.T, err error) {
} }
} }
func getNextPort() int {
nextPort++
return nextPort
}
func log(s string, v ...interface{}) { func log(s string, v ...interface{}) {
if testing.Verbose() { if testing.Verbose() {
fmt.Fprintf(os.Stderr, "> "+s+"\n", v...) fmt.Fprintf(os.Stderr, "> "+s+"\n", v...)
...@@ -69,17 +64,15 @@ func singleConn(t *testing.T, tr pst.Transport) echoSetup { ...@@ -69,17 +64,15 @@ func singleConn(t *testing.T, tr pst.Transport) echoSetup {
log("closing stream") log("closing stream")
}) })
port := getNextPort() log("listening at %s", "localhost:0")
addr := fmt.Sprintf("localhost:%d", port) l, err := net.Listen("tcp", "localhost:0")
log("listening at %s", addr)
l, err := net.Listen("tcp", addr)
checkErr(t, err) checkErr(t, err)
_, err = swarm.AddListener(l) _, err = swarm.AddListener(l)
checkErr(t, err) checkErr(t, err)
log("dialing to %s", addr) log("dialing to %s", l.Addr())
nc1, err := net.Dial("tcp", addr) nc1, err := net.Dial("tcp", l.Addr().String())
checkErr(t, err) checkErr(t, err)
c1, err := swarm.AddConn(nc1) c1, err := swarm.AddConn(nc1)
...@@ -101,10 +94,8 @@ func makeSwarm(t *testing.T, tr pst.Transport, nListeners int) *ps.Swarm { ...@@ -101,10 +94,8 @@ func makeSwarm(t *testing.T, tr pst.Transport, nListeners int) *ps.Swarm {
}) })
for i := 0; i < nListeners; i++ { for i := 0; i < nListeners; i++ {
port := getNextPort() log("%p listening at %s", swarm, "localhost:0")
addr := fmt.Sprintf("localhost:%d", port) l, err := net.Listen("tcp", "localhost:0")
log("%p listening at %s", swarm, addr)
l, err := net.Listen("tcp", addr)
checkErr(t, err) checkErr(t, err)
_, err = swarm.AddListener(l) _, err = swarm.AddListener(l)
checkErr(t, err) checkErr(t, err)
...@@ -138,17 +129,15 @@ func SubtestSimpleWrite(t *testing.T, tr pst.Transport) { ...@@ -138,17 +129,15 @@ func SubtestSimpleWrite(t *testing.T, tr pst.Transport) {
log("closing stream") log("closing stream")
}) })
port := getNextPort() log("listening at %s", "localhost:0")
addr := fmt.Sprintf("localhost:%d", port) l, err := net.Listen("tcp", "localhost:0")
log("listening at %s", addr)
l, err := net.Listen("tcp", addr)
checkErr(t, err) checkErr(t, err)
_, err = swarm.AddListener(l) _, err = swarm.AddListener(l)
checkErr(t, err) checkErr(t, err)
log("dialing to %s", addr) log("dialing to %s", l.Addr().String())
nc1, err := net.Dial("tcp", addr) nc1, err := net.Dial("tcp", l.Addr().String())
checkErr(t, err) checkErr(t, err)
c1, err := swarm.AddConn(nc1) c1, err := swarm.AddConn(nc1)
......
...@@ -118,9 +118,11 @@ func TestAllKeysRespectsContext(t *testing.T) { ...@@ -118,9 +118,11 @@ func TestAllKeysRespectsContext(t *testing.T) {
// Once without context, to make sure it all works // Once without context, to make sure it all works
{ {
var results dsq.Results var results dsq.Results
var resultsmu = make(chan struct{})
resultChan := make(chan dsq.Result) resultChan := make(chan dsq.Result)
d.SetFunc(func(q dsq.Query) (dsq.Results, error) { d.SetFunc(func(q dsq.Query) (dsq.Results, error) {
results = dsq.ResultsWithChan(q, resultChan) results = dsq.ResultsWithChan(q, resultChan)
resultsmu <- struct{}{}
return results, nil return results, nil
}) })
...@@ -128,6 +130,7 @@ func TestAllKeysRespectsContext(t *testing.T) { ...@@ -128,6 +130,7 @@ func TestAllKeysRespectsContext(t *testing.T) {
// make sure it's waiting. // make sure it's waiting.
<-started <-started
<-resultsmu
select { select {
case <-done: case <-done:
t.Fatal("sync is wrong") t.Fatal("sync is wrong")
...@@ -156,9 +159,11 @@ func TestAllKeysRespectsContext(t *testing.T) { ...@@ -156,9 +159,11 @@ func TestAllKeysRespectsContext(t *testing.T) {
// Once with // Once with
{ {
var results dsq.Results var results dsq.Results
var resultsmu = make(chan struct{})
resultChan := make(chan dsq.Result) resultChan := make(chan dsq.Result)
d.SetFunc(func(q dsq.Query) (dsq.Results, error) { d.SetFunc(func(q dsq.Query) (dsq.Results, error) {
results = dsq.ResultsWithChan(q, resultChan) results = dsq.ResultsWithChan(q, resultChan)
resultsmu <- struct{}{}
return results, nil return results, nil
}) })
...@@ -167,6 +172,7 @@ func TestAllKeysRespectsContext(t *testing.T) { ...@@ -167,6 +172,7 @@ func TestAllKeysRespectsContext(t *testing.T) {
// make sure it's waiting. // make sure it's waiting.
<-started <-started
<-resultsmu
select { select {
case <-done: case <-done:
t.Fatal("sync is wrong") t.Fatal("sync is wrong")
......
...@@ -188,16 +188,16 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) { ...@@ -188,16 +188,16 @@ func SubtestConnSendDisc(t *testing.T, hosts []host.Host) {
defer wg.Done() defer wg.Done()
go sF(s) go sF(s)
log.Debugf("getting handle %d", i) log.Debugf("getting handle %d", j)
sc := <-ss // wait to get handle. sc := <-ss // wait to get handle.
log.Debugf("spawning worker %d", i) log.Debugf("spawning worker %d", j)
for i := 0; i < numMsgs; i++ { for k := 0; k < numMsgs; k++ {
sc.send <- struct{}{} sc.send <- struct{}{}
<-sc.sent <-sc.sent
log.Debugf("%d sent %d", j, i) log.Debugf("%d sent %d", j, k)
<-sc.read <-sc.read
log.Debugf("%d read %d", j, i) log.Debugf("%d read %d", j, k)
} }
sc.close_ <- struct{}{} sc.close_ <- struct{}{}
<-sc.closed <-sc.closed
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论