Commit 605e57d8 authored by Adam Langley's avatar Adam Langley

exp/ssh: new package.

The typical UNIX method for controlling long running process is to
send the process signals. Since this doesn't get you very far, various
ad-hoc, remote-control protocols have been used over time by programs
like Apache and BIND.

Implementing an SSH server means that Go code will have a standard,
secure way to do this in the future.

R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san
CC=golang-dev
https://golang.org/cl/4962064
parent b71a805c
# Copyright 2011 The Go Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
include ../../../Make.inc
TARG=exp/ssh
GOFILES=\
common.go\
messages.go\
server.go\
transport.go\
channel.go\
server_shell.go\
include ../../../Make.pkg
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"os"
"sync"
)
// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
// SSH connection.
type Channel interface {
// Accept accepts the channel creation request.
Accept() os.Error
// Reject rejects the channel creation request. After calling this, no
// other methods on the Channel may be called. If they are then the
// peer is likely to signal a protocol error and drop the connection.
Reject(reason RejectionReason, message string) os.Error
// Read may return a ChannelRequest as an os.Error.
Read(data []byte) (int, os.Error)
Write(data []byte) (int, os.Error)
Close() os.Error
// AckRequest either sends an ack or nack to the channel request.
AckRequest(ok bool) os.Error
// ChannelType returns the type of the channel, as supplied by the
// client.
ChannelType() string
// ExtraData returns the arbitary payload for this channel, as supplied
// by the client. This data is specific to the channel type.
ExtraData() []byte
}
// ChannelRequest represents a request sent on a channel, outside of the normal
// stream of bytes. It may result from calling Read on a Channel.
type ChannelRequest struct {
Request string
WantReply bool
Payload []byte
}
func (c ChannelRequest) String() string {
return "channel request received"
}
// RejectionReason is an enumeration used when rejecting channel creation
// requests. See RFC 4254, section 5.1.
type RejectionReason int
const (
Prohibited RejectionReason = iota + 1
ConnectionFailed
UnknownChannelType
ResourceShortage
)
type channel struct {
// immutable once created
chanType string
extraData []byte
theyClosed bool
theySentEOF bool
weClosed bool
dead bool
serverConn *ServerConnection
myId, theirId uint32
myWindow, theirWindow uint32
maxPacketSize uint32
err os.Error
pendingRequests []ChannelRequest
pendingData []byte
head, length int
// This lock is inferior to serverConn.lock
lock sync.Mutex
cond *sync.Cond
}
func (c *channel) Accept() os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
confirm := channelOpenConfirmMsg{
PeersId: c.theirId,
MyId: c.myId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxPacketSize,
}
return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm))
}
func (c *channel) Reject(reason RejectionReason, message string) os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
reject := channelOpenFailureMsg{
PeersId: c.theirId,
Reason: uint32(reason),
Message: message,
Language: "en",
}
return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject))
}
func (c *channel) handlePacket(packet interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
switch packet := packet.(type) {
case *channelRequestMsg:
req := ChannelRequest{
Request: packet.Request,
WantReply: packet.WantReply,
Payload: packet.RequestSpecificData,
}
c.pendingRequests = append(c.pendingRequests, req)
c.cond.Signal()
case *channelCloseMsg:
c.theyClosed = true
c.cond.Signal()
case *channelEOFMsg:
c.theySentEOF = true
c.cond.Signal()
default:
panic("unknown packet type")
}
}
func (c *channel) handleData(data []byte) {
c.lock.Lock()
defer c.lock.Unlock()
// The other side should never send us more than our window.
if len(data)+c.length > len(c.pendingData) {
// TODO(agl): we should tear down the channel with a protocol
// error.
return
}
c.myWindow -= uint32(len(data))
for i := 0; i < 2; i++ {
tail := c.head + c.length
if tail > len(c.pendingData) {
tail -= len(c.pendingData)
}
n := copy(c.pendingData[tail:], data)
data = data[n:]
c.length += n
}
c.cond.Signal()
}
func (c *channel) Read(data []byte) (n int, err os.Error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.err != nil {
return 0, c.err
}
if c.myWindow <= uint32(len(c.pendingData))/2 {
packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
PeersId: c.theirId,
AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
})
if err := c.serverConn.out.writePacket(packet); err != nil {
return 0, err
}
}
for {
if c.theySentEOF || c.theyClosed || c.dead {
return 0, os.EOF
}
if len(c.pendingRequests) > 0 {
req := c.pendingRequests[0]
if len(c.pendingRequests) == 1 {
c.pendingRequests = nil
} else {
oldPendingRequests := c.pendingRequests
c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
copy(c.pendingRequests, oldPendingRequests[1:])
}
return 0, req
}
if c.length > 0 {
tail := c.head + c.length
if tail > len(c.pendingData) {
tail -= len(c.pendingData)
}
n = copy(data, c.pendingData[c.head:tail])
c.head += n
c.length -= n
if c.head == len(c.pendingData) {
c.head = 0
}
return
}
c.cond.Wait()
}
panic("unreachable")
}
func (c *channel) Write(data []byte) (n int, err os.Error) {
for len(data) > 0 {
c.lock.Lock()
if c.dead || c.weClosed {
return 0, os.EOF
}
if c.theirWindow == 0 {
c.cond.Wait()
continue
}
c.lock.Unlock()
todo := data
if uint32(len(todo)) > c.theirWindow {
todo = todo[:c.theirWindow]
}
packet := make([]byte, 1+4+4+len(todo))
packet[0] = msgChannelData
packet[1] = byte(c.theirId) >> 24
packet[2] = byte(c.theirId) >> 16
packet[3] = byte(c.theirId) >> 8
packet[4] = byte(c.theirId)
packet[5] = byte(len(todo)) >> 24
packet[6] = byte(len(todo)) >> 16
packet[7] = byte(len(todo)) >> 8
packet[8] = byte(len(todo))
copy(packet[9:], todo)
c.serverConn.lock.Lock()
if err = c.serverConn.out.writePacket(packet); err != nil {
c.serverConn.lock.Unlock()
return
}
c.serverConn.lock.Unlock()
n += len(todo)
data = data[len(todo):]
}
return
}
func (c *channel) Close() os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if c.weClosed {
return os.NewError("ssh: channel already closed")
}
c.weClosed = true
closeMsg := channelCloseMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg))
}
func (c *channel) AckRequest(ok bool) os.Error {
c.serverConn.lock.Lock()
defer c.serverConn.lock.Unlock()
if c.serverConn.err != nil {
return c.serverConn.err
}
if ok {
ack := channelRequestSuccessMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack))
} else {
ack := channelRequestFailureMsg{
PeersId: c.theirId,
}
return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack))
}
panic("unreachable")
}
func (c *channel) ChannelType() string {
return c.chanType
}
func (c *channel) ExtraData() []byte {
return c.extraData
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"strconv"
)
// These are string constants in the SSH protocol.
const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa"
cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96"
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
serviceSSH = "ssh-connection"
)
// UnexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
type UnexpectedMessageError struct {
expected, got uint8
}
func (u UnexpectedMessageError) String() string {
return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")"
}
// ParseError results from a malformed SSH message.
type ParseError struct {
msgType uint8
}
func (p ParseError) String() string {
return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
if clientAlgo == serverAlgo {
return clientAlgo, true
}
}
}
return
}
func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
return
}
hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
return
}
clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
ok = true
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package ssh implements an SSH server.
SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level
protocol is a remote shell and this is specifically implemented. However,
the multiplexed nature of SSH is exposed to users that wish to support
others.
An SSH server is represented by a Server, which manages a number of
ServerConnections and handles authentication.
var s Server
s.PubKeyCallback = pubKeyAuth
s.PasswordCallback = passwordAuth
pemBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
panic("Failed to load private key")
}
err = s.SetRSAPrivateKey(pemBytes)
if err != nil {
panic("Failed to parse private key")
}
Once a Server has been set up, connections can be attached.
var sConn ServerConnection
sConn.Server = &s
err = sConn.Handshake(conn)
if err != nil {
panic("failed to handshake")
}
An SSH connection multiplexes several channels, which must be accepted themselves:
for {
channel, err := sConn.Accept()
if err != nil {
panic("error from Accept")
}
...
}
Accept reads from the connection, demultiplexes packets to their corresponding
channels and returns when a new channel request is seen. Some goroutine must
always be calling Accept; otherwise no messages will be forwarded to the
channels.
Channels have a type, depending on the application level protocol intended. In
the case of a shell, the type is "session" and ServerShell may be used to
present a simple terminal interface.
if channel.ChannelType() != "session" {
c.Reject(RejectUnknownChannelType, "unknown channel type")
return
}
channel.Accept()
shell := NewServerShell(channel, "> ")
go func() {
defer channel.Close()
for {
line, err := shell.ReadLine()
if err != nil {
break
}
println(line)
}
return
}()
*/
package ssh
This diff is collapsed.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"big"
"rand"
"reflect"
"testing"
"testing/quick"
)
var intLengthTests = []struct {
val, length int
}{
{0, 4 + 0},
{1, 4 + 1},
{127, 4 + 1},
{128, 4 + 2},
{-1, 4 + 1},
}
func TestIntLength(t *testing.T) {
for _, test := range intLengthTests {
v := new(big.Int).SetInt64(int64(test.val))
length := intLength(v)
if length != test.length {
t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
}
}
}
var messageTypes = []interface{}{
&kexInitMsg{},
&kexDHInitMsg{},
&serviceRequestMsg{},
&serviceAcceptMsg{},
&userAuthRequestMsg{},
&channelOpenMsg{},
&channelOpenConfirmMsg{},
&channelRequestMsg{},
&channelRequestSuccessMsg{},
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for i, iface := range messageTypes {
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
}
m1 := v.Elem().Interface()
m2 := iface
marshaled := marshal(msgIgnore, m1)
if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
break
}
if !reflect.DeepEqual(v.Interface(), m2) {
t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
break
}
}
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())
}
}
func randomNameList(rand *rand.Rand) []string {
ret := make([]string, rand.Int31()&15)
for i := range ret {
s := make([]byte, 1+(rand.Int31()&15))
for j := range s {
s[j] = 'a' + uint8(rand.Int31()&15)
}
ret[i] = string(s)
}
return ret
}
func randomInt(rand *rand.Rand) *big.Int {
return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
}
func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
ki := &kexInitMsg{}
randomBytes(ki.Cookie[:], rand)
ki.KexAlgos = randomNameList(rand)
ki.ServerHostKeyAlgos = randomNameList(rand)
ki.CiphersClientServer = randomNameList(rand)
ki.CiphersServerClient = randomNameList(rand)
ki.MACsClientServer = randomNameList(rand)
ki.MACsServerClient = randomNameList(rand)
ki.CompressionClientServer = randomNameList(rand)
ki.CompressionServerClient = randomNameList(rand)
ki.LanguagesClientServer = randomNameList(rand)
ki.LanguagesServerClient = randomNameList(rand)
if rand.Int31()&1 == 1 {
ki.FirstKexFollows = true
}
return reflect.ValueOf(ki)
}
func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
dhi := &kexDHInitMsg{}
dhi.X = randomInt(rand)
return reflect.ValueOf(dhi)
}
This diff is collapsed.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"os"
)
// ServerShell contains the state for running a VT100 terminal that is capable
// of reading lines of input.
type ServerShell struct {
c Channel
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewServerShell runs a VT100 terminal on the given channel. prompt is a
// string that is written at the start of each input line. For example: "> ".
func NewServerShell(c Channel, prompt string) *ServerShell {
return &ServerShell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of ss.outBuf
func (ss *ServerShell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *ServerShell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *ServerShell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *ServerShell) ReadLine() (line string, err os.Error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
n, err := ss.c.Read(readBuf)
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", os.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
if req, ok := err.(ChannelRequest); ok {
ok := false
switch req.Request {
case "pty-req":
ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
if !ok {
ss.termWidth = 80
ss.termHeight = 24
}
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any commands, only the default shell.
ok = false
}
case "env":
ok = true
}
if req.WantReply {
ss.c.AckRequest(ok)
}
} else {
return "", err
}
}
panic("unreachable")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
"os"
)
type MockChannel struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockChannel) Accept() os.Error {
return nil
}
func (c *MockChannel) Reject(RejectionReason, string) os.Error {
return nil
}
func (c *MockChannel) Read(data []byte) (n int, err os.Error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, os.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockChannel) Write(data []byte) (n int, err os.Error) {
c.received = append(c.received, data...)
return len(data), nil
}
func (c *MockChannel) Close() os.Error {
return nil
}
func (c *MockChannel) AckRequest(ok bool) os.Error {
return nil
}
func (c *MockChannel) ChannelType() string {
return ""
}
func (c *MockChannel) ExtraData() []byte {
return nil
}
func TestClose(t *testing.T) {
c := &MockChannel{}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != os.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
err os.Error
}{
{
"",
"",
os.EOF,
},
{
"\r",
"",
nil,
},
{
"foo\r",
"foo",
nil,
},
{
"a\x1b[Cb\r", // right
"ab",
nil,
},
{
"a\x1b[Db\r", // left
"ba",
nil,
},
{
"a\177b\r", // backspace
"b",
nil,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 0; j < len(test.in); j++ {
c := &MockChannel{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewServerShell(c, "> ")
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/subtle"
"hash"
"io"
"net"
"os"
)
// halfConnection represents one direction of an SSH connection. It maintains
// the cipher state needed to process messages.
type halfConnection struct {
// Only one of these two will be non-nil
in *bufio.Reader
out net.Conn
rand io.Reader
cipherAlgo string
macAlgo string
compressionAlgo string
paddingMultiple int
seqNum uint32
mac hash.Hash
cipher cipher.Stream
}
func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
var lengthBytes [5]byte
_, err = io.ReadFull(hc.in, lengthBytes[:])
if err != nil {
return
}
if hc.cipher != nil {
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
}
macSize := 0
if hc.mac != nil {
hc.mac.Reset()
var seqNumBytes [4]byte
seqNumBytes[0] = byte(hc.seqNum >> 24)
seqNumBytes[1] = byte(hc.seqNum >> 16)
seqNumBytes[2] = byte(hc.seqNum >> 8)
seqNumBytes[3] = byte(hc.seqNum)
hc.mac.Write(seqNumBytes[:])
hc.mac.Write(lengthBytes[:])
macSize = hc.mac.Size()
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
paddingLength := uint32(lengthBytes[4])
if length <= paddingLength+1 {
return nil, os.NewError("invalid packet length")
}
if length > maxPacketSize {
return nil, os.NewError("packet too large")
}
packet = make([]byte, length-1+uint32(macSize))
_, err = io.ReadFull(hc.in, packet)
if err != nil {
return nil, err
}
mac := packet[length-1:]
if hc.cipher != nil {
hc.cipher.XORKeyStream(packet, packet[:length-1])
}
if hc.mac != nil {
hc.mac.Write(packet[:length-1])
if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
return nil, os.NewError("ssh: MAC failure")
}
}
hc.seqNum++
packet = packet[:length-paddingLength-1]
return
}
func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
for {
packet, err := hc.readOnePacket()
if err != nil {
return nil, err
}
if packet[0] != msgIgnore && packet[0] != msgDebug {
return packet, nil
}
}
panic("unreachable")
}
func (hc *halfConnection) writePacket(packet []byte) os.Error {
paddingMultiple := hc.paddingMultiple
if paddingMultiple == 0 {
paddingMultiple = 8
}
paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
if paddingLength < 4 {
paddingLength += paddingMultiple
}
var lengthBytes [5]byte
length := len(packet) + 1 + paddingLength
lengthBytes[0] = byte(length >> 24)
lengthBytes[1] = byte(length >> 16)
lengthBytes[2] = byte(length >> 8)
lengthBytes[3] = byte(length)
lengthBytes[4] = byte(paddingLength)
var padding [32]byte
_, err := io.ReadFull(hc.rand, padding[:paddingLength])
if err != nil {
return err
}
if hc.mac != nil {
hc.mac.Reset()
var seqNumBytes [4]byte
seqNumBytes[0] = byte(hc.seqNum >> 24)
seqNumBytes[1] = byte(hc.seqNum >> 16)
seqNumBytes[2] = byte(hc.seqNum >> 8)
seqNumBytes[3] = byte(hc.seqNum)
hc.mac.Write(seqNumBytes[:])
hc.mac.Write(lengthBytes[:])
hc.mac.Write(packet)
hc.mac.Write(padding[:paddingLength])
}
if hc.cipher != nil {
hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
hc.cipher.XORKeyStream(packet, packet)
hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
}
_, err = hc.out.Write(lengthBytes[:])
if err != nil {
return err
}
_, err = hc.out.Write(packet)
if err != nil {
return err
}
_, err = hc.out.Write(padding[:paddingLength])
if err != nil {
return err
}
if hc.mac != nil {
_, err = hc.out.Write(hc.mac.Sum())
}
hc.seqNum++
return err
}
const (
serverKeys = iota
clientKeys
)
// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
h := hashFunc.New()
// We only support these algorithms for now.
if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
return os.NewError("ssh: setupServerKeys internal error")
}
blockSize := 16
keySize := 16
macKeySize := 20
var ivTag, keyTag, macKeyTag byte
if direction == serverKeys {
ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
} else {
ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
}
iv := make([]byte, blockSize)
key := make([]byte, keySize)
macKey := make([]byte, macKeySize)
generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
generateKeyMaterial(key, keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
if err != nil {
return err
}
hc.cipher = cipher.NewCTR(aes, iv)
hc.paddingMultiple = 16
return nil
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
var digestsSoFar []byte
for len(out) > 0 {
h.Reset()
h.Write(K)
h.Write(H)
if len(digestsSoFar) == 0 {
h.Write([]byte{tag})
h.Write(sessionId)
} else {
h.Write(digestsSoFar)
}
digest := h.Sum()
n := copy(out, digest)
out = out[n:]
if len(out) > 0 {
digestsSoFar = append(digestsSoFar, digest...)
}
}
}
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
// a given size.
type truncatingMAC struct {
length int
hmac hash.Hash
}
func (t truncatingMAC) Write(data []byte) (int, os.Error) {
return t.hmac.Write(data)
}
func (t truncatingMAC) Sum() []byte {
digest := t.hmac.Sum()
return digest[:t.length]
}
func (t truncatingMAC) Reset() {
t.hmac.Reset()
}
func (t truncatingMAC) Size() int {
return t.length
}
// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
// version string. In the event that the client is talking a different protocol
// we need to set a limit otherwise we will keep using more and more memory
// while searching for the end of the version handshake.
const maxVersionStringBytes = 1024
func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
versionString = make([]byte, 0, 64)
seenCR := false
forEachByte:
for len(versionString) < maxVersionStringBytes {
b, err := r.ReadByte()
if err != nil {
return
}
if !seenCR {
if b == '\r' {
seenCR = true
}
} else {
if b == '\n' {
ok = true
break forEachByte
} else {
seenCR = false
}
}
versionString = append(versionString, b)
}
if ok {
// We need to remove the CR from versionString
versionString = versionString[:len(versionString)-1]
}
return
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment