提交 061e1ab8 作者: Juan Batiz-Benet

net: better protocol headers

上级 c150668a
...@@ -46,7 +46,7 @@ type Conn interface { ...@@ -46,7 +46,7 @@ type Conn interface {
conn.PeerConn conn.PeerConn
// NewStream constructs a new Stream directly connected to p. // NewStream constructs a new Stream directly connected to p.
NewStream(p peer.Peer) (Stream, error) NewStream(pr ProtocolID, p peer.Peer) (Stream, error)
} }
// Network is the interface IPFS uses for connecting to the world. // Network is the interface IPFS uses for connecting to the world.
...@@ -63,7 +63,8 @@ type Network interface { ...@@ -63,7 +63,8 @@ type Network interface {
// NewStream returns a new stream to given peer p. // NewStream returns a new stream to given peer p.
// If there is no connection to p, attempts to create one. // If there is no connection to p, attempts to create one.
NewStream(p peer.Peer) (Stream, error) // If ProtocolID is "", writes no header.
NewStream(ProtocolID, peer.Peer) (Stream, error)
// Swarm returns the connection Swarm // Swarm returns the connection Swarm
Swarm() *swarm.Swarm Swarm() *swarm.Swarm
......
...@@ -37,31 +37,10 @@ type Mux struct { ...@@ -37,31 +37,10 @@ type Mux struct {
sync.RWMutex sync.RWMutex
} }
// NextName reads the stream and returns the next protocol name // ReadProtocolHeader reads the stream and returns the next Handler function
// according to the muxer encoding. // according to the muxer encoding.
func (m *Mux) NextName(s io.Reader) (string, error) { func (m *Mux) ReadProtocolHeader(s io.Reader) (string, StreamHandler, error) {
name, err := ReadLengthPrefix(s)
// c-string identifier
// the first byte is our length
l := make([]byte, 1)
if _, err := io.ReadFull(s, l); err != nil {
return "", err
}
length := int(l[0])
// the next are our identifier
name := make([]byte, length)
if _, err := io.ReadFull(s, name); err != nil {
return "", err
}
return string(name), nil
}
// NextHandler reads the stream and returns the next Handler function
// according to the muxer encoding.
func (m *Mux) NextHandler(s io.Reader) (string, StreamHandler, error) {
name, err := m.NextName(s)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
...@@ -92,7 +71,7 @@ func (m *Mux) SetHandler(p ProtocolID, h StreamHandler) { ...@@ -92,7 +71,7 @@ func (m *Mux) SetHandler(p ProtocolID, h StreamHandler) {
func (m *Mux) Handle(s Stream) { func (m *Mux) Handle(s Stream) {
ctx := context.Background() ctx := context.Background()
name, handler, err := m.NextHandler(s) name, handler, err := m.ReadProtocolHeader(s)
if err != nil { if err != nil {
err = fmt.Errorf("protocol mux error: %s", err) err = fmt.Errorf("protocol mux error: %s", err)
log.Error(err) log.Error(err)
...@@ -105,8 +84,27 @@ func (m *Mux) Handle(s Stream) { ...@@ -105,8 +84,27 @@ func (m *Mux) Handle(s Stream) {
handler(s) handler(s)
} }
// Write writes the name into Writer with a length-byte-prefix. // ReadLengthPrefix reads the name from Reader with a length-byte-prefix.
func Write(w io.Writer, name string) error { func ReadLengthPrefix(r io.Reader) (string, error) {
// c-string identifier
// the first byte is our length
l := make([]byte, 1)
if _, err := io.ReadFull(r, l); err != nil {
return "", err
}
length := int(l[0])
// the next are our identifier
name := make([]byte, length)
if _, err := io.ReadFull(r, name); err != nil {
return "", err
}
return string(name), nil
}
// WriteLengthPrefix writes the name into Writer with a length-byte-prefix.
func WriteLengthPrefix(w io.Writer, name string) error {
s := make([]byte, len(name)+1) s := make([]byte, len(name)+1)
s[0] = byte(len(name)) s[0] = byte(len(name))
copy(s[1:], []byte(name)) copy(s[1:], []byte(name))
......
...@@ -15,7 +15,7 @@ var testCases = map[string]string{ ...@@ -15,7 +15,7 @@ var testCases = map[string]string{
func TestWrite(t *testing.T) { func TestWrite(t *testing.T) {
for k, v := range testCases { for k, v := range testCases {
var buf bytes.Buffer var buf bytes.Buffer
Write(&buf, k) WriteLengthPrefix(&buf, k)
v2 := buf.Bytes() v2 := buf.Bytes()
if !bytes.Equal(v2, []byte(v)) { if !bytes.Equal(v2, []byte(v)) {
...@@ -48,7 +48,7 @@ func TestHandler(t *testing.T) { ...@@ -48,7 +48,7 @@ func TestHandler(t *testing.T) {
continue continue
} }
name, _, err := m.NextHandler(&buf) name, _, err := m.ReadProtocolHeader(&buf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
continue continue
......
...@@ -44,12 +44,20 @@ func (c *conn_) SwarmConn() *swarm.Conn { ...@@ -44,12 +44,20 @@ func (c *conn_) SwarmConn() *swarm.Conn {
return (*swarm.Conn)(c) return (*swarm.Conn)(c)
} }
func (c *conn_) NewStream(p peer.Peer) (Stream, error) { func (c *conn_) NewStream(pr ProtocolID, p peer.Peer) (Stream, error) {
s, err := (*swarm.Conn)(c).NewStream() s, err := (*swarm.Conn)(c).NewStream()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return (*stream)(s), nil
ss := (*stream)(s)
if err := writeProtocolHeader(pr, ss); err != nil {
ss.Close()
return nil, err
}
return ss, nil
} }
// LocalMultiaddr is the Multiaddr on this side // LocalMultiaddr is the Multiaddr on this side
...@@ -154,8 +162,36 @@ func (n *network) Connectedness(p peer.Peer) Connectedness { ...@@ -154,8 +162,36 @@ func (n *network) Connectedness(p peer.Peer) Connectedness {
return NotConnected return NotConnected
} }
// NewStream returns a new stream to given peer p.
// If there is no connection to p, attempts to create one.
// If ProtocolID is "", writes no header.
func (c *network) NewStreamWithPeer(pr ProtocolID, p peer.Peer) (Stream, error) {
s, err := c.swarm.NewStreamWithPeer(p)
if err != nil {
return nil, err
}
ss := (*stream)(s)
if err := writeProtocolHeader(pr, ss); err != nil {
ss.Close()
return nil, err
}
return ss, nil
}
// SetHandler sets the protocol handler on the Network's Muxer. // SetHandler sets the protocol handler on the Network's Muxer.
// This operation is threadsafe. // This operation is threadsafe.
func (n *network) SetHandler(p ProtocolID, h StreamHandler) { func (n *network) SetHandler(p ProtocolID, h StreamHandler) {
n.mux.SetHandler(p, h) n.mux.SetHandler(p, h)
} }
func writeProtocolHeader(pr ProtocolID, s Stream) error {
if pr != "" { // only write proper protocol headers
if err := WriteLengthPrefix(s, string(pr)); err != nil {
return err
}
}
return nil
}
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论