Commit c4d0ac0e authored by Dave Cheney's avatar Dave Cheney Committed by Gustavo Niemeyer

exp/ssh: add Std{in,out,err}Pipe methods to Session

R=gustav.paul, cw, agl, rsc, n13m3y3r
CC=golang-dev
https://golang.org/cl/5433080
parent c0a53bbc
......@@ -54,9 +54,10 @@ type Session struct {
*clientChan // the channel backing this session
started bool // true once a Shell or Run is invoked.
copyFuncs []func() error
errch chan error // one send per copyFunc
started bool // true once Start, Run or Shell is invoked.
closeAfterWait []io.Closer
copyFuncs []func() error
errch chan error // one send per copyFunc
}
// RFC 4254 Section 6.4.
......@@ -231,7 +232,7 @@ func (s *Session) start() error {
return nil
}
// Wait waits for the remote command to exit.
// Wait waits for the remote command to exit.
func (s *Session) Wait() error {
if !s.started {
return errors.New("ssh: session not started")
......@@ -244,11 +245,12 @@ func (s *Session) Wait() error {
copyError = err
}
}
for _, fd := range s.closeAfterWait {
fd.Close()
}
if waitErr != nil {
return waitErr
}
return copyError
}
......@@ -283,11 +285,15 @@ func (s *Session) stdin() error {
s.Stdin = new(bytes.Buffer)
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(&chanWriter{
w := &chanWriter{
packetWriter: s,
peersId: s.peersId,
win: s.win,
}, s.Stdin)
}
_, err := io.Copy(w, s.Stdin)
if err1 := w.Close(); err == nil {
err = err1
}
return err
})
return nil
......@@ -298,11 +304,12 @@ func (s *Session) stdout() error {
s.Stdout = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, &chanReader{
r := &chanReader{
packetWriter: s,
peersId: s.peersId,
data: s.data,
})
}
_, err := io.Copy(s.Stdout, r)
return err
})
return nil
......@@ -313,16 +320,72 @@ func (s *Session) stderr() error {
s.Stderr = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, &chanReader{
r := &chanReader{
packetWriter: s,
peersId: s.peersId,
data: s.dataExt,
})
}
_, err := io.Copy(s.Stderr, r)
return err
})
return nil
}
// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
if s.Stdin != nil {
return nil, errors.New("ssh: Stdin already set")
}
if s.started {
return nil, errors.New("ssh: StdinPipe after process started")
}
pr, pw := io.Pipe()
s.Stdin = pr
s.closeAfterWait = append(s.closeAfterWait, pr)
return pw, nil
}
// StdoutPipe returns a pipe that will be connected to the
// remote command's standard output when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StdoutPipe reader is
// not serviced fast enought it may eventually cause the
// remote command to block.
func (s *Session) StdoutPipe() (io.ReadCloser, error) {
if s.Stdout != nil {
return nil, errors.New("ssh: Stdout already set")
}
if s.started {
return nil, errors.New("ssh: StdoutPipe after process started")
}
pr, pw := io.Pipe()
s.Stdout = pw
s.closeAfterWait = append(s.closeAfterWait, pw)
return pr, nil
}
// StderrPipe returns a pipe that will be connected to the
// remote command's standard error when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StderrPipe reader is
// not serviced fast enought it may eventually cause the
// remote command to block.
func (s *Session) StderrPipe() (io.ReadCloser, error) {
if s.Stderr != nil {
return nil, errors.New("ssh: Stderr already set")
}
if s.started {
return nil, errors.New("ssh: StderrPipe after process started")
}
pr, pw := io.Pipe()
s.Stderr = pw
s.closeAfterWait = append(s.closeAfterWait, pw)
return pr, nil
}
// TODO(dfc) add Output and CombinedOutput helpers
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport)
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
"io"
"testing"
)
// dial constructs a new test server and returns a *ClientConn.
func dial(t *testing.T) *ClientConn {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
go func() {
defer l.Close()
conn, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept: %v", err)
return
}
defer conn.Close()
if err := conn.Handshake(); err != nil {
t.Errorf("Unable to handshake: %v", err)
return
}
for {
ch, err := conn.Accept()
if err == io.EOF {
return
}
if err != nil {
t.Errorf("Unable to accept incoming channel request: %v", err)
return
}
if ch.ChannelType() != "session" {
ch.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch.Accept()
go func() {
defer ch.Close()
// this string is returned to stdout
shell := NewServerShell(ch, "golang")
shell.ReadLine()
type exitMsg struct {
PeersId uint32
Request string
WantReply bool
Status uint32
}
// TODO(dfc) casting to the concrete type should not be
// necessary to send a packet.
msg := exitMsg{
PeersId: ch.(*channel).theirId,
Request: "exit-status",
WantReply: false,
Status: 0,
}
ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
}()
}
t.Log("done")
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
return c
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment