提交 5341379f 作者: Juan Batiz-Benet

net/conn: io, not channels

This commit changes the connections to use io.ReadWriters
instead of channels (+ async workers). This is a pretty
big change -- away from csp -- in the name of performance
(and predictable flow control).

It also uses the brand new secio, which is spipe's successor.
上级 48bf4461
...@@ -18,9 +18,6 @@ import ( ...@@ -18,9 +18,6 @@ import (
var log = u.Logger("conn") var log = u.Logger("conn")
const ( const (
// ChanBuffer is the size of the buffer in the Conn Chan
ChanBuffer = 10
// MaxMessageSize is the size of the largest single message. (4MB) // MaxMessageSize is the size of the largest single message. (4MB)
MaxMessageSize = 1 << 22 MaxMessageSize = 1 << 22
...@@ -35,25 +32,12 @@ func ReleaseBuffer(b []byte) { ...@@ -35,25 +32,12 @@ func ReleaseBuffer(b []byte) {
mpool.ByteSlicePool.Put(uint32(cap(b)), b) mpool.ByteSlicePool.Put(uint32(cap(b)), b)
} }
// msgioPipe is a pipe using msgio channels.
type msgioPipe struct {
outgoing *msgio.Chan
incoming *msgio.Chan
}
func newMsgioPipe(size int) *msgioPipe {
return &msgioPipe{
outgoing: msgio.NewChan(size),
incoming: msgio.NewChan(size),
}
}
// singleConn represents a single connection to another Peer (IPFS Node). // singleConn represents a single connection to another Peer (IPFS Node).
type singleConn struct { type singleConn struct {
local peer.Peer local peer.Peer
remote peer.Peer remote peer.Peer
maconn manet.Conn maconn manet.Conn
msgio *msgioPipe msgrw msgio.ReadWriteCloser
ctxc.ContextCloser ctxc.ContextCloser
} }
...@@ -66,24 +50,12 @@ func newSingleConn(ctx context.Context, local, remote peer.Peer, ...@@ -66,24 +50,12 @@ func newSingleConn(ctx context.Context, local, remote peer.Peer,
local: local, local: local,
remote: remote, remote: remote,
maconn: maconn, maconn: maconn,
msgio: newMsgioPipe(10), msgrw: msgio.NewReadWriter(maconn),
} }
conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close) conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close)
log.Debugf("newSingleConn: %v to %v", local, remote) log.Debugf("newSingleConn %p: %v to %v", conn, local, remote)
// setup the various io goroutines
conn.Children().Add(1)
go func() {
conn.msgio.outgoing.WriteTo(maconn)
conn.Children().Done()
}()
conn.Children().Add(1)
go func() {
conn.msgio.incoming.ReadFromWithPool(maconn, &mpool.ByteSlicePool)
conn.Children().Done()
}()
// version handshake // version handshake
ctxT, _ := context.WithTimeout(ctx, HandshakeTimeout) ctxT, _ := context.WithTimeout(ctx, HandshakeTimeout)
...@@ -92,6 +64,7 @@ func newSingleConn(ctx context.Context, local, remote peer.Peer, ...@@ -92,6 +64,7 @@ func newSingleConn(ctx context.Context, local, remote peer.Peer,
return nil, fmt.Errorf("Handshake1 failed: %s", err) return nil, fmt.Errorf("Handshake1 failed: %s", err)
} }
log.Debugf("newSingleConn %p: %v to %v finished", conn, local, remote)
return conn, nil return conn, nil
} }
...@@ -100,20 +73,7 @@ func (c *singleConn) close() error { ...@@ -100,20 +73,7 @@ func (c *singleConn) close() error {
log.Debugf("%s closing Conn with %s", c.local, c.remote) log.Debugf("%s closing Conn with %s", c.local, c.remote)
// close underlying connection // close underlying connection
err := c.maconn.Close() return c.msgrw.Close()
c.msgio.outgoing.Close()
return err
}
func (c *singleConn) GetError() error {
select {
case err := <-c.msgio.incoming.ErrChan:
return err
case err := <-c.msgio.outgoing.ErrChan:
return err
default:
return nil
}
} }
// ID is an identifier unique to this connection. // ID is an identifier unique to this connection.
...@@ -145,14 +105,29 @@ func (c *singleConn) RemotePeer() peer.Peer { ...@@ -145,14 +105,29 @@ func (c *singleConn) RemotePeer() peer.Peer {
return c.remote return c.remote
} }
// In returns a readable message channel // Read reads data, net.Conn style
func (c *singleConn) In() <-chan []byte { func (c *singleConn) Read(buf []byte) (int, error) {
return c.msgio.incoming.MsgChan return c.msgrw.Read(buf)
}
// Write writes data, net.Conn style
func (c *singleConn) Write(buf []byte) (int, error) {
return c.msgrw.Write(buf)
}
// ReadMsg reads data, net.Conn style
func (c *singleConn) ReadMsg() ([]byte, error) {
return c.msgrw.ReadMsg()
}
// WriteMsg writes data, net.Conn style
func (c *singleConn) WriteMsg(buf []byte) error {
return c.msgrw.WriteMsg(buf)
} }
// Out returns a writable message channel // ReleaseMsg releases a buffer
func (c *singleConn) Out() chan<- []byte { func (c *singleConn) ReleaseMsg(m []byte) {
return c.msgio.outgoing.MsgChan c.msgrw.ReleaseMsg(m)
} }
// ID returns the ID of a given Conn. // ID returns the ID of a given Conn.
...@@ -167,6 +142,6 @@ func ID(c Conn) string { ...@@ -167,6 +142,6 @@ func ID(c Conn) string {
// String returns the user-friendly String representation of a conn // String returns the user-friendly String representation of a conn
func String(c Conn, typ string) string { func String(c Conn, typ string) string {
return fmt.Sprintf("%s (%s) <-- %s --> (%s) %s", return fmt.Sprintf("%s (%s) <-- %s %p --> (%s) %s",
c.LocalPeer(), c.LocalMultiaddr(), typ, c.RemoteMultiaddr(), c.RemotePeer()) c.LocalPeer(), c.LocalMultiaddr(), typ, c, c.RemoteMultiaddr(), c.RemotePeer())
} }
...@@ -100,18 +100,24 @@ func TestCloseLeak(t *testing.T) { ...@@ -100,18 +100,24 @@ func TestCloseLeak(t *testing.T) {
c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/"+a1, "/ip4/127.0.0.1/tcp/"+a2) c1, c2 := setupConn(t, ctx, "/ip4/127.0.0.1/tcp/"+a1, "/ip4/127.0.0.1/tcp/"+a2)
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
b1 := []byte("beep") b1 := []byte(fmt.Sprintf("beep%d", i))
c1.Out() <- b1 c1.WriteMsg(b1)
b2 := <-c2.In() b2, err := c2.ReadMsg()
if err != nil {
panic(err)
}
if !bytes.Equal(b1, b2) { if !bytes.Equal(b1, b2) {
panic("bytes not equal") panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2))
} }
b2 = []byte("boop") b2 = []byte(fmt.Sprintf("boop%d", i))
c2.Out() <- b2 c2.WriteMsg(b2)
b1 = <-c1.In() b1, err = c1.ReadMsg()
if err != nil {
panic(err)
}
if !bytes.Equal(b1, b2) { if !bytes.Equal(b1, b2) {
panic("bytes not equal") panic(fmt.Errorf("bytes not equal: %s != %s", b1, b2))
} }
<-time.After(time.Microsecond * 5) <-time.After(time.Microsecond * 5)
...@@ -123,7 +129,7 @@ func TestCloseLeak(t *testing.T) { ...@@ -123,7 +129,7 @@ func TestCloseLeak(t *testing.T) {
wg.Done() wg.Done()
} }
var cons = 20 var cons = 1
var msgs = 100 var msgs = 100
fmt.Printf("Running %d connections * %d msgs.\n", cons, msgs) fmt.Printf("Running %d connections * %d msgs.\n", cons, msgs)
for i := 0; i < cons; i++ { for i := 0; i < cons; i++ {
......
...@@ -72,5 +72,6 @@ func (d *Dialer) DialAddr(ctx context.Context, raddr ma.Multiaddr, remote peer.P ...@@ -72,5 +72,6 @@ func (d *Dialer) DialAddr(ctx context.Context, raddr ma.Multiaddr, remote peer.P
return nil, err return nil, err
} }
// return c, nil
return newSecureConn(ctx, c, d.Peerstore) return newSecureConn(ctx, c, d.Peerstore)
} }
package conn package conn
import ( import (
"io"
"testing" "testing"
ci "github.com/jbenet/go-ipfs/crypto" ci "github.com/jbenet/go-ipfs/crypto"
...@@ -42,14 +43,7 @@ func echoListen(ctx context.Context, listener Listener) { ...@@ -42,14 +43,7 @@ func echoListen(ctx context.Context, listener Listener) {
} }
func echo(ctx context.Context, c Conn) { func echo(ctx context.Context, c Conn) {
for { io.Copy(c, c)
select {
case <-ctx.Done():
return
case m := <-c.In():
c.Out() <- m
}
}
} }
func setupConn(t *testing.T, ctx context.Context, a1, a2 string) (a, b Conn) { func setupConn(t *testing.T, ctx context.Context, a1, a2 string) (a, b Conn) {
...@@ -137,17 +131,25 @@ func TestDialer(t *testing.T) { ...@@ -137,17 +131,25 @@ func TestDialer(t *testing.T) {
} }
// fmt.Println("sending") // fmt.Println("sending")
c.Out() <- []byte("beep") c.WriteMsg([]byte("beep"))
c.Out() <- []byte("boop") c.WriteMsg([]byte("boop"))
out, err := c.ReadMsg()
if err != nil {
t.Fatal(err)
}
out := <-c.In()
// fmt.Println("recving", string(out)) // fmt.Println("recving", string(out))
data := string(out) data := string(out)
if data != "beep" { if data != "beep" {
t.Error("unexpected conn output", data) t.Error("unexpected conn output", data)
} }
out = <-c.In() out, err = c.ReadMsg()
if err != nil {
t.Fatal(err)
}
data = string(out) data = string(out)
if string(out) != "boop" { if string(out) != "boop" {
t.Error("unexpected conn output", data) t.Error("unexpected conn output", data)
...@@ -207,17 +209,24 @@ func TestDialAddr(t *testing.T) { ...@@ -207,17 +209,24 @@ func TestDialAddr(t *testing.T) {
} }
// fmt.Println("sending") // fmt.Println("sending")
c.Out() <- []byte("beep") c.WriteMsg([]byte("beep"))
c.Out() <- []byte("boop") c.WriteMsg([]byte("boop"))
out := <-c.In() out, err := c.ReadMsg()
if err != nil {
t.Fatal(err)
}
// fmt.Println("recving", string(out)) // fmt.Println("recving", string(out))
data := string(out) data := string(out)
if data != "beep" { if data != "beep" {
t.Error("unexpected conn output", data) t.Error("unexpected conn output", data)
} }
out = <-c.In() out, err = c.ReadMsg()
if err != nil {
t.Fatal(err)
}
data = string(out) data = string(out)
if string(out) != "boop" { if string(out) != "boop" {
t.Error("unexpected conn output", data) t.Error("unexpected conn output", data)
......
package conn package conn
import ( import (
"errors"
"fmt" "fmt"
handshake "github.com/jbenet/go-ipfs/net/handshake" handshake "github.com/jbenet/go-ipfs/net/handshake"
...@@ -25,29 +24,22 @@ func Handshake1(ctx context.Context, c Conn) error { ...@@ -25,29 +24,22 @@ func Handshake1(ctx context.Context, c Conn) error {
return err return err
} }
c.Out() <- myVerBytes if err := CtxWriteMsg(ctx, c, myVerBytes); err != nil {
log.Debugf("Sent my version (%s) to %s", localH, rpeer) return err
}
select { log.Debugf("%p sent my version (%s) to %s", c, localH, rpeer)
case <-ctx.Done():
return ctx.Err()
case <-c.Closing():
return errors.New("remote closed connection during version exchange")
case data, ok := <-c.In():
if !ok {
return fmt.Errorf("error retrieving from conn: %v", rpeer)
}
remoteH = new(hspb.Handshake1) data, err := CtxReadMsg(ctx, c)
err = proto.Unmarshal(data, remoteH) if err != nil {
if err != nil { return err
return fmt.Errorf("could not decode remote version: %q", err) }
}
log.Debugf("Received remote version (%s) from %s", remoteH, rpeer) remoteH = new(hspb.Handshake1)
err = proto.Unmarshal(data, remoteH)
if err != nil {
return fmt.Errorf("could not decode remote version: %q", err)
} }
log.Debugf("%p received remote version (%s) from %s", c, remoteH, rpeer)
if err := handshake.Handshake1Compatible(localH, remoteH); err != nil { if err := handshake.Handshake1Compatible(localH, remoteH); err != nil {
log.Infof("%s (%s) incompatible version with %s (%s)", lpeer, localH, rpeer, remoteH) log.Infof("%s (%s) incompatible version with %s (%s)", lpeer, localH, rpeer, remoteH)
...@@ -71,31 +63,25 @@ func Handshake3(ctx context.Context, c Conn) (*handshake.Handshake3Result, error ...@@ -71,31 +63,25 @@ func Handshake3(ctx context.Context, c Conn) (*handshake.Handshake3Result, error
return nil, err return nil, err
} }
c.Out() <- localB if err := CtxWriteMsg(ctx, c, localB); err != nil {
return nil, err
}
log.Debugf("Handshake1: sent to %s", rpeer) log.Debugf("Handshake1: sent to %s", rpeer)
// wait + listen for response // wait + listen for response
select { remoteB, err := CtxReadMsg(ctx, c)
case <-ctx.Done(): if err != nil {
return nil, ctx.Err() return nil, err
}
case <-c.Closing():
return nil, errors.New("Handshake3: error remote connection closed")
case remoteB, ok := <-c.In():
if !ok {
return nil, fmt.Errorf("Handshake3 error receiving from conn: %v", rpeer)
}
remoteH = new(hspb.Handshake3)
err = proto.Unmarshal(remoteB, remoteH)
if err != nil {
return nil, fmt.Errorf("Handshake3 could not decode remote msg: %q", err)
}
log.Debugf("Handshake3 received from %s", rpeer) remoteH = new(hspb.Handshake3)
err = proto.Unmarshal(remoteB, remoteH)
if err != nil {
return nil, fmt.Errorf("Handshake3 could not decode remote msg: %q", err)
} }
log.Debugf("Handshake3 received from %s", rpeer)
// actually update our state based on the new knowledge // actually update our state based on the new knowledge
res, err := handshake.Handshake3Update(lpeer, rpeer, remoteH) res, err := handshake.Handshake3Update(lpeer, rpeer, remoteH)
if err != nil { if err != nil {
......
package conn package conn
import ( import (
"errors"
peer "github.com/jbenet/go-ipfs/peer" peer "github.com/jbenet/go-ipfs/peer"
u "github.com/jbenet/go-ipfs/util" u "github.com/jbenet/go-ipfs/util"
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
msgio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio"
ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr"
) )
...@@ -31,15 +35,8 @@ type Conn interface { ...@@ -31,15 +35,8 @@ type Conn interface {
// RemotePeer is the Peer on the remote side // RemotePeer is the Peer on the remote side
RemotePeer() peer.Peer RemotePeer() peer.Peer
// In returns a readable message channel msgio.Reader
In() <-chan []byte msgio.Writer
// Out returns a writable message channel
Out() chan<- []byte
// Get an error from this conn if one is available
// TODO: implement a better error handling system
GetError() error
} }
// Dialer is an object that can open connections. We could have a "convenience" // Dialer is an object that can open connections. We could have a "convenience"
...@@ -77,3 +74,89 @@ type Listener interface { ...@@ -77,3 +74,89 @@ type Listener interface {
// Any blocked Accept operations will be unblocked and return errors. // Any blocked Accept operations will be unblocked and return errors.
Close() error Close() error
} }
// CtxRead is a function that Reads from a connection while respecting a
// Context. Though it cannot cancel the read per-se (as not all Connections
// implement SetTimeout, and a CancelFunc can't be predicted), at least it
// doesn't hang. The Read will eventually return and the goroutine will exit.
func CtxRead(ctx context.Context, c Conn, buf []byte) (n int, err error) {
done := make(chan struct{})
go func() {
n, err = c.Read(buf)
close(done)
}()
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-c.Closing():
return 0, errors.New("remote connection closed")
case <-done:
return n, err
}
}
// CtxReadMsg is a function that Reads from a connection while respecting a
// Context. See CtxRead.
func CtxReadMsg(ctx context.Context, c Conn) (msg []byte, err error) {
done := make(chan struct{})
go func() {
msg, err = c.ReadMsg()
close(done)
}()
select {
case <-ctx.Done():
return msg, ctx.Err()
case <-c.Closing():
return msg, errors.New("remote connection closed")
case <-done:
return msg, err
}
}
// CtxWrite is a function that Writes to a connection while respecting a
// Context. See CtxRead.
func CtxWrite(ctx context.Context, c Conn, buf []byte) (n int, err error) {
done := make(chan struct{})
go func() {
n, err = c.Read(buf)
close(done)
}()
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-c.Closing():
return 0, errors.New("remote connection closed")
case <-done:
return n, err
}
}
// CtxWriteMsg is a function that Writes to a connection while respecting a
// Context. See CtxRead.
func CtxWriteMsg(ctx context.Context, c Conn, buf []byte) (err error) {
done := make(chan struct{})
go func() {
err = c.WriteMsg(buf)
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-c.Closing():
return errors.New("remote connection closed")
case <-done:
return err
}
}
...@@ -65,12 +65,15 @@ func (l *listener) listen() { ...@@ -65,12 +65,15 @@ func (l *listener) listen() {
return return
} }
// if insecure:
// l.conns <- c
// if secure
sc, err := newSecureConn(l.ctx, c, l.peers) sc, err := newSecureConn(l.ctx, c, l.peers)
if err != nil { if err != nil {
log.Errorf("Error securing connection: %v", err) log.Errorf("Error securing connection: %v", err)
return return
} }
l.conns <- sc l.conns <- sc
} }
......
package conn package conn
import ( import (
"errors"
"sync" "sync"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
...@@ -14,12 +15,6 @@ import ( ...@@ -14,12 +15,6 @@ import (
// MultiConnMap is for shorthand // MultiConnMap is for shorthand
type MultiConnMap map[u.Key]*MultiConn type MultiConnMap map[u.Key]*MultiConn
// Duplex is a simple duplex channel
type Duplex struct {
In chan []byte
Out chan []byte
}
// MultiConn represents a single connection to another Peer (IPFS Node). // MultiConn represents a single connection to another Peer (IPFS Node).
type MultiConn struct { type MultiConn struct {
...@@ -30,8 +25,8 @@ type MultiConn struct { ...@@ -30,8 +25,8 @@ type MultiConn struct {
local peer.Peer local peer.Peer
remote peer.Peer remote peer.Peer
// fan-in/fan-out // fan-in
duplex Duplex fanIn chan []byte
// for adding/removing connections concurrently // for adding/removing connections concurrently
sync.RWMutex sync.RWMutex
...@@ -45,10 +40,7 @@ func NewMultiConn(ctx context.Context, local, remote peer.Peer, conns []Conn) (* ...@@ -45,10 +40,7 @@ func NewMultiConn(ctx context.Context, local, remote peer.Peer, conns []Conn) (*
local: local, local: local,
remote: remote, remote: remote,
conns: map[string]Conn{}, conns: map[string]Conn{},
duplex: Duplex{ fanIn: make(chan []byte),
In: make(chan []byte, 10),
Out: make(chan []byte, 10),
},
} }
// must happen before Adds / fanOut // must happen before Adds / fanOut
...@@ -58,8 +50,6 @@ func NewMultiConn(ctx context.Context, local, remote peer.Peer, conns []Conn) (* ...@@ -58,8 +50,6 @@ func NewMultiConn(ctx context.Context, local, remote peer.Peer, conns []Conn) (*
c.Add(conns...) c.Add(conns...)
} }
c.Children().Add(1)
go c.fanOut()
return c, nil return c, nil
} }
...@@ -135,38 +125,8 @@ func CloseConns(conns ...Conn) { ...@@ -135,38 +125,8 @@ func CloseConns(conns ...Conn) {
wg.Wait() wg.Wait()
} }
// fanOut is the multiplexor out -- it sends outgoing messages over the // fanInSingle Reads from a connection, and sends to the fanIn.
// underlying single connections. // waits for child to close and reclaims resources
func (c *MultiConn) fanOut() {
defer c.Children().Done()
i := 0
for {
select {
case <-c.Closing():
return
// send data out through our "best connection"
case m, more := <-c.duplex.Out:
if !more {
log.Debugf("%s out channel closed", c)
return
}
sc := c.BestConn()
if sc == nil {
// maybe this should be a logged error, not a panic.
panic("sending out multiconn without any live connection")
}
i++
log.Debugf("%s sending (%d)", sc, i)
sc.Out() <- m
}
}
}
// fanInSingle is a multiplexor in -- it receives incoming messages over the
// underlying single connections.
func (c *MultiConn) fanInSingle(child Conn) { func (c *MultiConn) fanInSingle(child Conn) {
// cleanup all data associated with this child Connection. // cleanup all data associated with this child Connection.
defer func() { defer func() {
...@@ -186,8 +146,13 @@ func (c *MultiConn) fanInSingle(child Conn) { ...@@ -186,8 +146,13 @@ func (c *MultiConn) fanInSingle(child Conn) {
} }
}() }()
i := 0
for { for {
msg, err := child.ReadMsg()
if err != nil {
log.Warning(err)
return
}
select { select {
case <-c.Closing(): // multiconn closing case <-c.Closing(): // multiconn closing
return return
...@@ -195,18 +160,7 @@ func (c *MultiConn) fanInSingle(child Conn) { ...@@ -195,18 +160,7 @@ func (c *MultiConn) fanInSingle(child Conn) {
case <-child.Closing(): // child closing case <-child.Closing(): // child closing
return return
case m, more := <-child.In(): // receiving data case c.fanIn <- msg:
if !more {
log.Debugf("%s in channel closed", child)
err := c.GetError()
if err != nil {
log.Errorf("Found error on connection: %s", err)
}
return // closed
}
i++
log.Debugf("%s received (%d)", child, i)
c.duplex.In <- m
} }
} }
} }
...@@ -296,24 +250,38 @@ func (c *MultiConn) RemotePeer() peer.Peer { ...@@ -296,24 +250,38 @@ func (c *MultiConn) RemotePeer() peer.Peer {
return c.remote return c.remote
} }
// In returns a readable message channel // Read reads data, net.Conn style
func (c *MultiConn) In() <-chan []byte { func (c *MultiConn) Read(buf []byte) (int, error) {
return c.duplex.In return 0, errors.New("multiconn does not support Read. use ReadMsg")
} }
// Out returns a writable message channel // Write writes data, net.Conn style
func (c *MultiConn) Out() chan<- []byte { func (c *MultiConn) Write(buf []byte) (int, error) {
return c.duplex.Out bc := c.BestConn()
if bc == nil {
return 0, errors.New("no best connection")
}
return bc.Write(buf)
} }
func (c *MultiConn) GetError() error { // ReadMsg reads data, net.Conn style
c.RLock() func (c *MultiConn) ReadMsg() ([]byte, error) {
defer c.RUnlock() next := <-c.fanIn
for _, sub := range c.conns { return next, nil
err := sub.GetError() }
if err != nil {
return err // WriteMsg writes data, net.Conn style
} func (c *MultiConn) WriteMsg(buf []byte) error {
bc := c.BestConn()
if bc == nil {
return errors.New("no best connection")
}
return bc.WriteMsg(buf)
}
// ReleaseMsg releases a buffer
func (c *MultiConn) ReleaseMsg(m []byte) {
for _, c := range c.getConns() {
c.ReleaseMsg(m)
} }
return nil
} }
...@@ -178,7 +178,7 @@ func TestMulticonnSend(t *testing.T) { ...@@ -178,7 +178,7 @@ func TestMulticonnSend(t *testing.T) {
for _, m := range msgs.msgs { for _, m := range msgs.msgs {
log.Info("send: %s", m.payload) log.Info("send: %s", m.payload)
c.Out() <- []byte(m.payload) c.WriteMsg([]byte(m.payload))
msgs.Sent(t, m.payload) msgs.Sent(t, m.payload)
<-time.After(time.Microsecond * 10) <-time.After(time.Microsecond * 10)
} }
...@@ -189,16 +189,20 @@ func TestMulticonnSend(t *testing.T) { ...@@ -189,16 +189,20 @@ func TestMulticonnSend(t *testing.T) {
for { for {
select { select {
case payload := <-c.In(): default:
msgs.Received(t, string(payload))
log.Info("recv: %s", payload)
if msgs.recv == len(msgs.msgs) {
return
}
case <-ctx.Done(): case <-ctx.Done():
return return
}
payload, err := c.ReadMsg()
if err != nil {
panic(err)
}
msgs.Received(t, string(payload))
log.Info("recv: %s", payload)
if msgs.recv == len(msgs.msgs) {
return
} }
} }
...@@ -252,11 +256,11 @@ func TestMulticonnSendUnderlying(t *testing.T) { ...@@ -252,11 +256,11 @@ func TestMulticonnSendUnderlying(t *testing.T) {
log.Info("send: %s", m.payload) log.Info("send: %s", m.payload)
switch i % 3 { switch i % 3 {
case 0: case 0:
conns[0].Out() <- []byte(m.payload) conns[0].WriteMsg([]byte(m.payload))
case 1: case 1:
conns[1].Out() <- []byte(m.payload) conns[1].WriteMsg([]byte(m.payload))
case 2: case 2:
c.Out() <- []byte(m.payload) c.WriteMsg([]byte(m.payload))
} }
msgs.Sent(t, m.payload) msgs.Sent(t, m.payload)
<-time.After(time.Microsecond * 10) <-time.After(time.Microsecond * 10)
...@@ -269,16 +273,20 @@ func TestMulticonnSendUnderlying(t *testing.T) { ...@@ -269,16 +273,20 @@ func TestMulticonnSendUnderlying(t *testing.T) {
for { for {
select { select {
case payload := <-c.In(): default:
msgs.Received(t, string(payload))
log.Info("recv: %s", payload)
if msgs.recv == len(msgs.msgs) {
return
}
case <-ctx.Done(): case <-ctx.Done():
return return
}
payload, err := c.ReadMsg()
if err != nil {
panic(err)
}
msgs.Received(t, string(payload))
log.Info("recv: %s", payload)
if msgs.recv == len(msgs.msgs) {
return
} }
} }
......
package conn package conn
import ( import (
"errors"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
msgio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio"
ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr"
spipe "github.com/jbenet/go-ipfs/crypto/spipe" secio "github.com/jbenet/go-ipfs/crypto/secio"
peer "github.com/jbenet/go-ipfs/peer" peer "github.com/jbenet/go-ipfs/peer"
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
"github.com/jbenet/go-ipfs/util/pipes"
) )
// secureConn wraps another Conn object with an encrypted channel. // secureConn wraps another Conn object with an encrypted channel.
...@@ -18,8 +16,11 @@ type secureConn struct { ...@@ -18,8 +16,11 @@ type secureConn struct {
// the wrapped conn // the wrapped conn
insecure Conn insecure Conn
// secure pipe, wrapping insecure // secure io (wrapping insecure)
secure *spipe.SecurePipe secure msgio.ReadWriteCloser
// secure Session
session secio.Session
ctxc.ContextCloser ctxc.ContextCloser
} }
...@@ -27,74 +28,30 @@ type secureConn struct { ...@@ -27,74 +28,30 @@ type secureConn struct {
// newConn constructs a new connection // newConn constructs a new connection
func newSecureConn(ctx context.Context, insecure Conn, peers peer.Peerstore) (Conn, error) { func newSecureConn(ctx context.Context, insecure Conn, peers peer.Peerstore) (Conn, error) {
// NewSession performs the secure handshake, which takes multiple RTT
sessgen := secio.SessionGenerator{Local: insecure.LocalPeer(), Peerstore: peers}
session, err := sessgen.NewSession(ctx, insecure)
if err != nil {
return nil, err
}
conn := &secureConn{ conn := &secureConn{
insecure: insecure, insecure: insecure,
session: session,
secure: session.ReadWriter(),
} }
conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close) conn.ContextCloser = ctxc.NewContextCloser(ctx, conn.close)
log.Debugf("newSecureConn: %v to %v handshake success!", conn.LocalPeer(), conn.RemotePeer())
log.Debugf("newSecureConn: %v to %v", insecure.LocalPeer(), insecure.RemotePeer())
// perform secure handshake before returning this connection.
if err := conn.secureHandshake(peers); err != nil {
conn.Close()
return nil, err
}
log.Debugf("newSecureConn: %v to %v handshake success!", insecure.LocalPeer(), insecure.RemotePeer())
return conn, nil return conn, nil
} }
// secureHandshake performs the spipe secure handshake.
func (c *secureConn) secureHandshake(peers peer.Peerstore) error {
if c.secure != nil {
return errors.New("Conn is already secured or being secured.")
}
// ok to panic here if this type assertion fails. Interface hack.
// when we support wrapping other Conns, we'll need to change
// spipe to do something else.
insecureSC := c.insecure.(*singleConn)
// setup a Duplex pipe for spipe
insecureD := pipes.Duplex{
In: insecureSC.msgio.incoming.MsgChan,
Out: insecureSC.msgio.outgoing.MsgChan,
}
// spipe performs the secure handshake, which takes multiple RTT
sp, err := spipe.NewSecurePipe(c.Context(), 10, c.LocalPeer(), peers, insecureD)
if err != nil {
return err
}
// assign it into the conn object
c.secure = sp
// if we do not know RemotePeer, get it from secure chan (who identifies it)
if insecureSC.remote == nil {
insecureSC.remote = c.secure.RemotePeer()
} else if insecureSC.remote != c.secure.RemotePeer() {
// this panic is here because this would be an insidious programmer error
// that we need to ensure we catch.
// update: this actually might happen under normal operation-- should
// perhaps return an error. TBD.
log.Errorf("secureConn peer mismatch. %v != %v", insecureSC.remote, c.secure.RemotePeer())
log.Errorf("insecureSC.remote: %s %#v", insecureSC.remote, insecureSC.remote)
log.Errorf("c.secure.LocalPeer: %s %#v", c.secure.RemotePeer(), c.secure.RemotePeer())
panic("secureConn peer mismatch. consructed incorrectly?")
}
return nil
}
// close is called by ContextCloser // close is called by ContextCloser
func (c *secureConn) close() error { func (c *secureConn) close() error {
err := c.insecure.Close() if err := c.secure.Close(); err != nil {
if c.secure != nil { // may never have gotten here. c.insecure.Close()
err = c.secure.Close() return err
} }
return err return c.insecure.Close()
} }
// ID is an identifier unique to this connection. // ID is an identifier unique to this connection.
...@@ -118,24 +75,35 @@ func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { ...@@ -118,24 +75,35 @@ func (c *secureConn) RemoteMultiaddr() ma.Multiaddr {
// LocalPeer is the Peer on this side // LocalPeer is the Peer on this side
func (c *secureConn) LocalPeer() peer.Peer { func (c *secureConn) LocalPeer() peer.Peer {
return c.insecure.LocalPeer() return c.session.LocalPeer()
} }
// RemotePeer is the Peer on the remote side // RemotePeer is the Peer on the remote side
func (c *secureConn) RemotePeer() peer.Peer { func (c *secureConn) RemotePeer() peer.Peer {
return c.insecure.RemotePeer() return c.session.RemotePeer()
}
// Read reads data, net.Conn style
func (c *secureConn) Read(buf []byte) (int, error) {
return c.secure.Read(buf)
}
// Write writes data, net.Conn style
func (c *secureConn) Write(buf []byte) (int, error) {
return c.secure.Write(buf)
} }
// In returns a readable message channel // ReadMsg reads data, net.Conn style
func (c *secureConn) In() <-chan []byte { func (c *secureConn) ReadMsg() ([]byte, error) {
return c.secure.In return c.secure.ReadMsg()
} }
// Out returns a writable message channel // WriteMsg writes data, net.Conn style
func (c *secureConn) Out() chan<- []byte { func (c *secureConn) WriteMsg(buf []byte) error {
return c.secure.Out return c.secure.WriteMsg(buf)
} }
func (c *secureConn) GetError() error { // ReleaseMsg releases a buffer
return c.insecure.GetError() func (c *secureConn) ReleaseMsg(m []byte) {
c.secure.ReleaseMsg(m)
} }
...@@ -105,6 +105,8 @@ func TestSecureCancel(t *testing.T) { ...@@ -105,6 +105,8 @@ func TestSecureCancel(t *testing.T) {
} }
func TestSecureCloseLeak(t *testing.T) { func TestSecureCloseLeak(t *testing.T) {
// t.Skip("Skipping in favor of another test")
if testing.Short() { if testing.Short() {
t.SkipNow() t.SkipNow()
} }
...@@ -125,15 +127,21 @@ func TestSecureCloseLeak(t *testing.T) { ...@@ -125,15 +127,21 @@ func TestSecureCloseLeak(t *testing.T) {
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
b1 := []byte("beep") b1 := []byte("beep")
c1.Out() <- b1 c1.WriteMsg(b1)
b2 := <-c2.In() b2, err := c2.ReadMsg()
if err != nil {
panic(err)
}
if !bytes.Equal(b1, b2) { if !bytes.Equal(b1, b2) {
panic("bytes not equal") panic("bytes not equal")
} }
b2 = []byte("boop") b2 = []byte("beep")
c2.Out() <- b2 c2.WriteMsg(b2)
b1 = <-c1.In() b1, err = c1.ReadMsg()
if err != nil {
panic(err)
}
if !bytes.Equal(b1, b2) { if !bytes.Equal(b1, b2) {
panic("bytes not equal") panic("bytes not equal")
} }
......
...@@ -200,7 +200,10 @@ func (s *Swarm) fanOut() { ...@@ -200,7 +200,10 @@ func (s *Swarm) fanOut() {
log.Debugf("%s sent message to %s (%d)", s.local, msg.Peer(), i) log.Debugf("%s sent message to %s (%d)", s.local, msg.Peer(), i)
log.Event(context.TODO(), "sendMessage", s.local, msg) log.Event(context.TODO(), "sendMessage", s.local, msg)
// queue it in the connection's buffer // queue it in the connection's buffer
c.Out() <- msg.Data() if err := c.WriteMsg(msg.Data()); err != nil {
log.Infof("%s connection failed to write: %s", c, err)
continue
}
} }
} }
} }
...@@ -219,6 +222,9 @@ func (s *Swarm) fanInSingle(c conn.Conn) { ...@@ -219,6 +222,9 @@ func (s *Swarm) fanInSingle(c conn.Conn) {
c.Children().Done() // child of Conn as well. c.Children().Done() // child of Conn as well.
}() }()
// use readChan to be able to listen to Closing events
rchan := readChan(s.Context(), c)
i := 0 i := 0
for { for {
select { select {
...@@ -228,7 +234,7 @@ func (s *Swarm) fanInSingle(c conn.Conn) { ...@@ -228,7 +234,7 @@ func (s *Swarm) fanInSingle(c conn.Conn) {
case <-c.Closing(): // Conn closing case <-c.Closing(): // Conn closing
return return
case data, ok := <-c.In(): case data, ok := <-rchan:
if !ok { if !ok {
log.Infof("%s in channel closed", c) log.Infof("%s in channel closed", c)
return // channel closed. return // channel closed.
...@@ -240,26 +246,30 @@ func (s *Swarm) fanInSingle(c conn.Conn) { ...@@ -240,26 +246,30 @@ func (s *Swarm) fanInSingle(c conn.Conn) {
} }
} }
// Commenting out because it's platform specific // readChan is a temporary fixture to match the old interface. will be removed soon.
// func setSocketReuse(l manet.Listener) error { func readChan(ctx context.Context, c conn.Conn) <-chan []byte {
// nl := l.NetListener()
// ch := make(chan []byte) // no buffer. sync.
// // for now only TCP. TODO change this when more networks.
// file, err := nl.(*net.TCPListener).File() go func() {
// if err != nil { defer close(ch)
// return err
// } for {
// msg, err := c.ReadMsg()
// fd := file.Fd() if err != nil {
// err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) log.Infof("%s connection failed: %s", c, err)
// if err != nil { return
// return err }
// }
// select {
// err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1) case <-c.Closing():
// if err != nil { return
// return err case <-ctx.Done():
// } return
// case ch <- msg:
// return nil }
// } }
}()
return ch
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论