Commit 9628d4c6 authored by Firmansyah Adiputra's avatar Firmansyah Adiputra Committed by Nigel Tao

Add authentication.

Other code fixing:
- Fixed bugs in get32.
- Fix code for parsing display string (as a new function).
- Fix code for connecting to X server. The old code only work
  if the server is listening to TCP port, otherwise it doesn't
  work (at least in my PC).

R=nigeltao_golang, rsc, jhh
CC=golang-dev
https://golang.org/cl/183111
parent f43d95fb
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xgb
import (
"bufio"
"io"
"os"
)
func getU16BE(r io.Reader, b []byte) (uint16, os.Error) {
_, err := io.ReadFull(r, b[0:2])
if err != nil {
return 0, err
}
return uint16(b[0])<<8 + uint16(b[1]), nil
}
func getBytes(r io.Reader, b []byte) ([]byte, os.Error) {
n, err := getU16BE(r, b)
if err != nil {
return nil, err
}
if int(n) > len(b) {
return nil, os.NewError("bytes too long for buffer")
}
_, err = io.ReadFull(r, b[0:n])
if err != nil {
return nil, err
}
return b[0:n], nil
}
func getString(r io.Reader, b []byte) (string, os.Error) {
b, err := getBytes(r, b)
if err != nil {
return "", err
}
return string(b), nil
}
// readAuthority reads the X authority file for the DISPLAY.
// If hostname == "" or hostname == "localhost",
// readAuthority uses the system's hostname (as returned by os.Hostname) instead.
func readAuthority(hostname, display string) (name string, data []byte, err os.Error) {
// b is a scratch buffer to use and should be at least 256 bytes long
// (i.e. it should be able to hold a hostname).
var b [256]byte
// As per /usr/include/X11/Xauth.h.
const familyLocal = 256
if len(hostname) == 0 || hostname == "localhost" {
hostname, err = os.Hostname()
if err != nil {
return "", nil, err
}
}
fname := os.Getenv("XAUTHORITY")
if len(fname) == 0 {
home := os.Getenv("HOME")
if len(home) == 0 {
err = os.NewError("Xauthority not found: $XAUTHORITY, $HOME not set")
return "", nil, err
}
fname = home + "/.Xauthority"
}
r, err := os.Open(fname, os.O_RDONLY, 0444)
if err != nil {
return "", nil, err
}
defer r.Close()
br := bufio.NewReader(r)
for {
family, err := getU16BE(br, b[0:2])
if err != nil {
return "", nil, err
}
addr, err := getString(br, b[0:])
if err != nil {
return "", nil, err
}
disp, err := getString(br, b[0:])
if err != nil {
return "", nil, err
}
name0, err := getString(br, b[0:])
if err != nil {
return "", nil, err
}
data0, err := getBytes(br, b[0:])
if err != nil {
return "", nil, err
}
if family == familyLocal && addr == hostname && disp == display {
return name0, data0, nil
}
}
panic("unreachable")
}
......@@ -18,12 +18,14 @@ import (
// A Conn represents a connection to an X server.
// Only one goroutine should use a Conn's methods at a time.
type Conn struct {
host string
conn net.Conn
nextId Id
nextCookie Cookie
replies map[Cookie][]byte
events queue
err os.Error
display string
defaultScreen int
scratch [32]byte
Setup SetupInfo
......@@ -89,7 +91,7 @@ func get32(buf []byte) uint32 {
v := uint32(buf[0])
v |= uint32(buf[1]) << 8
v |= uint32(buf[2]) << 16
v |= uint32(buf[3]) << 32
v |= uint32(buf[3]) << 24
return v
}
......@@ -284,49 +286,39 @@ func (c *Conn) PollForEvent() (Event, os.Error) {
}
// Dial connects to the X server given in the 'display' string.
// The display string is typically taken from os.Getenv("DISPLAY").
// If 'display' is empty it will be taken from os.Getenv("DISPLAY").
//
// Examples:
// Dial(":1") // connect to net.Dial("unix", "", "/tmp/.X11-unix/X1")
// Dial("/tmp/launch-123/:0") // connect to net.Dial("unix", "", "/tmp/launch-123/:0")
// Dial("hostname:2.1") // connect to net.Dial("tcp", "", "hostname:6002")
// Dial("tcp/hostname:1.0") // connect to net.Dial("tcp", "", "hostname:6001")
func Dial(display string) (*Conn, os.Error) {
var err os.Error
c := new(Conn)
c, err := connect(display)
if err != nil {
return nil, err
}
if display[0] == '/' {
c.conn, err = net.Dial("unix", "", display)
if err != nil {
fmt.Printf("cannot connect: %v\n", err)
return nil, err
}
} else {
parts := strings.Split(display, ":", 2)
host := parts[0]
port := 0
if len(parts) > 1 {
parts = strings.Split(parts[1], ".", 2)
port, _ = strconv.Atoi(parts[0])
if len(parts) > 1 {
c.defaultScreen, _ = strconv.Atoi(parts[1])
}
}
display = fmt.Sprintf("%s:%d", host, port+6000)
c.conn, err = net.Dial("tcp", "", display)
if err != nil {
fmt.Printf("cannot connect: %v\n", err)
return nil, err
}
// Get authentication data
authName, authData, err := readAuthority(c.host, c.display)
if err != nil {
return nil, err
}
// TODO: get these from .Xauthority
var authName, authData []byte
// Assume that the authentication protocol is "MIT-MAGIC-COOKIE-1".
if authName != "MIT-MAGIC-COOKIE-1" || len(authData) != 16 {
return nil, os.NewError("unsupported auth protocol " + authName)
}
buf := make([]byte, 12+pad(len(authName))+pad(len(authData)))
buf[0] = 'l'
buf[0] = 0x6c
buf[1] = 0
put16(buf[2:], 11)
put16(buf[4:], 0)
put16(buf[6:], uint16(len(authName)))
put16(buf[8:], uint16(len(authData)))
put16(buf[10:], 0)
copy(buf[12:], authName)
copy(buf[12:], strings.Bytes(authName))
copy(buf[12+pad(len(authName)):], authData)
if _, err = c.conn.Write(buf); err != nil {
return nil, err
......@@ -396,3 +388,77 @@ func getClientMessageData(b []byte, v *ClientMessageData) int {
}
return 20
}
func connect(display string) (*Conn, os.Error) {
if len(display) == 0 {
display = os.Getenv("DISPLAY")
}
display0 := display
if len(display) == 0 {
return nil, os.NewError("empty display string")
}
colonIdx := strings.LastIndex(display, ":")
if colonIdx < 0 {
return nil, os.NewError("bad display string: " + display0)
}
var protocol, socket string
c := new(Conn)
if display[0] == '/' {
socket = display[0:colonIdx]
} else {
slashIdx := strings.LastIndex(display, "/")
if slashIdx >= 0 {
protocol = display[0:slashIdx]
c.host = display[slashIdx+1 : colonIdx]
} else {
c.host = display[0:colonIdx]
}
}
display = display[colonIdx+1 : len(display)]
if len(display) == 0 {
return nil, os.NewError("bad display string: " + display0)
}
var scr string
dotIdx := strings.LastIndex(display, ".")
if dotIdx < 0 {
c.display = display[0:]
} else {
c.display = display[0:dotIdx]
scr = display[dotIdx+1:]
}
dispnum, err := strconv.Atoui(c.display)
if err != nil {
return nil, os.NewError("bad display string: " + display0)
}
if len(scr) != 0 {
c.defaultScreen, err = strconv.Atoi(scr)
if err != nil {
return nil, os.NewError("bad display string: " + display0)
}
}
// Connect to server
if len(socket) != 0 {
c.conn, err = net.Dial("unix", "", socket+":"+c.display)
} else if len(c.host) != 0 {
if protocol == "" {
protocol = "tcp"
}
c.conn, err = net.Dial(protocol, "", c.host+":"+strconv.Uitoa(6000+dispnum))
} else {
c.conn, err = net.Dial("unix", "", "/tmp/.X11-unix/X"+c.display)
}
if err != nil {
return nil, os.NewError("cannot connect to " + display0 + ": " + err.String())
}
return c, nil
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment