diff options
Diffstat (limited to 'irc/states.go')
-rw-r--r-- | irc/states.go | 625 |
1 files changed, 625 insertions, 0 deletions
diff --git a/irc/states.go b/irc/states.go new file mode 100644 index 0000000..177c68f --- /dev/null +++ b/irc/states.go @@ -0,0 +1,625 @@ +package irc + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + "strings" + "sync/atomic" + "time" +) + +type SASLClient interface { + Handshake() (mech string) + Respond(challenge string) (res string, err error) +} + +type SASLPlain struct { + Username string + Password string +} + +func (auth *SASLPlain) Handshake() (mech string) { + mech = "PLAIN" + return +} + +func (auth *SASLPlain) Respond(challenge string) (res string, err error) { + if challenge != "+" { + err = errors.New("Unexpected challenge") + return + } + + user := []byte(auth.Username) + pass := []byte(auth.Password) + payload := bytes.Join([][]byte{user, user, pass}, []byte{0}) + res = base64.StdEncoding.EncodeToString(payload) + + return +} + +var SupportedCapabilities = map[string]struct{}{ + "account-notify": {}, + "account-tag": {}, + "away-notify": {}, + "batch": {}, + "cap-notify": {}, + "echo-message": {}, + "extended-join": {}, + "invite-notify": {}, + "labeled-response": {}, + "message-tags": {}, + "multi-prefix": {}, + "server-time": {}, + "sasl": {}, + "setname": {}, + "userhost-in-names": {}, +} + +type ConnectionState int + +const ( + ConnStart ConnectionState = iota + ConnRegistered + ConnQuit +) + +type User struct { + Nick string + AwayMsg string +} + +type Channel struct { + Name string + Members map[string]string + Topic string + TopicWho string + TopicTime time.Time + Secret bool +} + +type action interface{} + +type ( + actionPrivMsg struct { + To string + Content string + } + + actionTyping struct { + To string + } + actionTypingStop struct { + To string + } +) + +type SessionParams struct { + Nickname string + Username string + RealName string + + Auth SASLClient +} + +type Session struct { + conn io.ReadWriteCloser + msgs chan Message + acts chan action + evts chan Event + + running atomic.Value // bool + state ConnectionState + typingStamps map[string]time.Time + + nick string + lNick string + user string + real string + acct string + host string + auth SASLClient + + mode string + motd string + + availableCaps map[string]string + enabledCaps map[string]struct{} + features map[string]string + + users map[string]User + channels map[string]Channel +} + +func NewSession(conn io.ReadWriteCloser, params SessionParams) (s Session, err error) { + s = Session{ + conn: conn, + msgs: make(chan Message, 10), + acts: make(chan action, 10), + evts: make(chan Event, 10), + typingStamps: map[string]time.Time{}, + nick: params.Nickname, + lNick: strings.ToLower(params.Nickname), + user: params.Username, + real: params.RealName, + auth: params.Auth, + availableCaps: map[string]string{}, + enabledCaps: map[string]struct{}{}, + features: map[string]string{}, + users: map[string]User{}, + channels: map[string]Channel{}, + } + + s.running.Store(true) + + err = s.send("CAP LS 302\r\nNICK %s\r\nUSER %s 0 * :%s\r\n", s.nick, s.user, s.real) + if err != nil { + return + } + + go func() { + r := bufio.NewScanner(conn) + + for r.Scan() { + line := r.Text() + //fmt.Println(" > ", line) + + msg, err := Tokenize(line) + if err != nil { + continue + } + + err = msg.Validate() + if err != nil { + continue + } + + s.msgs <- msg + } + + s.Stop() + }() + + go s.run() + + return +} + +func (s *Session) Running() bool { + return s.running.Load().(bool) +} + +func (s *Session) Stop() { + s.running.Store(false) + s.conn.Close() +} + +func (s *Session) Poll() (events <-chan Event) { + return s.evts +} + +func (s *Session) IsChannel(name string) bool { + return strings.IndexAny(name, "#&") == 0 // TODO compute CHANTYPES +} + +func (s *Session) PrivMsg(to, content string) { + s.acts <- actionPrivMsg{to, content} +} + +func (s *Session) privMsg(act actionPrivMsg) (err error) { + err = s.send("PRIVMSG %s :%s\r\n", act.To, act.Content) + return +} + +func (s *Session) Typing(to string) { + s.acts <- actionTyping{to} +} + +func (s *Session) typing(act actionTyping) (err error) { + if _, ok := s.enabledCaps["message-tags"]; !ok { + return + } + + to := strings.ToLower(act.To) + + if t, ok := s.typingStamps[to]; ok && time.Now().Sub(t) < 3 { + return + } + + err = s.send("@+typing=active TAGMSG %s\r\n", act.To) + return +} + +func (s *Session) TypingStop(to string) { + s.acts <- actionTypingStop{to} +} + +func (s *Session) typingStop(act actionTypingStop) (err error) { + if _, ok := s.enabledCaps["message-tags"]; !ok { + return + } + + err = s.send("@+typing=done TAGMSG %s\r\n", act.To) + return +} + +func (s *Session) run() { + for s.Running() { + var ( + ev Event + err error + ) + + select { + case act := <-s.acts: + switch act := act.(type) { + case actionPrivMsg: + err = s.privMsg(act) + case actionTyping: + err = s.typing(act) + case actionTypingStop: + err = s.typingStop(act) + } + case msg := <-s.msgs: + if s.state == ConnStart { + ev, err = s.handleStart(msg) + } else if s.state == ConnRegistered { + ev, err = s.handle(msg) + } + } + + if ev != nil { + s.evts <- ev + } + if err != nil { + s.evts <- err + } + } +} + +func (s *Session) handleStart(msg Message) (ev Event, err error) { + switch msg.Command { + case "AUTHENTICATE": + if s.auth != nil { + var res string + + res, err = s.auth.Respond(msg.Params[0]) + if err != nil { + err = s.send("AUTHENTICATE *\r\n") + return + } + + err = s.send("AUTHENTICATE %s\r\n", res) + if err != nil { + return + } + } + case "900": + err = s.send("CAP END\r\n") + if err != nil { + return + } + + s.acct = msg.Params[2] + _, _, s.host = FullMask(msg.Params[1]) + case "902", "904", "905", "906", "907", "908": + err = s.send("CAP END\r\n") + if err != nil { + return + } + case "CAP": + switch msg.Params[1] { + case "LS": + var willContinue bool + var ls string + + if msg.Params[2] == "*" { + willContinue = true + ls = msg.Params[3] + } else { + willContinue = false + ls = msg.Params[2] + } + + for _, c := range TokenizeCaps(ls) { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + if !willContinue { + var req strings.Builder + + for c := range s.availableCaps { + if _, ok := SupportedCapabilities[c]; !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c) + } + + _, ok := s.availableCaps["sasl"] + if s.auth == nil || !ok { + _, _ = fmt.Fprintf(&req, "CAP END\r\n") + } + + err = s.send(req.String()) + if err != nil { + return + } + } + case "ACK": + for _, c := range strings.Split(msg.Params[2], " ") { + s.enabledCaps[c] = struct{}{} + + if s.auth != nil && c == "sasl" { + h := s.auth.Handshake() + err = s.send("AUTHENTICATE %s\r\n", h) + if err != nil { + return + } + } + } + } + case "372": // RPL_MOTD + s.motd += "\n" + strings.TrimPrefix(msg.Params[1], "- ") + case "433": // ERR_NICKNAMEINUSE + s.nick = s.nick + "_" + + err = s.send("NICK %s\r\n", s.nick) + if err != nil { + return + } + default: + ev, err = s.handle(msg) + } + + return +} + +func (s *Session) handle(msg Message) (ev Event, err error) { + switch msg.Command { + case "001": // RPL_WELCOME + s.nick = msg.Params[0] + s.lNick = strings.ToLower(s.nick) + s.state = ConnRegistered + ev = RegisteredEvent{} + + if s.host == "" { + err = s.send("WHO %s\r\n", s.nick) + if err != nil { + return + } + } + case "005": // RPL_ISUPPORT + s.updateFeatures(msg.Params[1 : len(msg.Params)-1]) + case "352": // RPL_WHOREPLY + if s.lNick == strings.ToLower(msg.Params[5]) { + s.host = msg.Params[3] + } + case "CAP": + switch msg.Params[1] { + case "ACK": + for _, c := range strings.Split(msg.Params[2], " ") { + s.enabledCaps[c] = struct{}{} + } + case "NAK": + for _, c := range strings.Split(msg.Params[2], " ") { + delete(s.enabledCaps, c) + } + case "NEW": + diff := TokenizeCaps(msg.Params[2]) + + for _, c := range diff { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + var req strings.Builder + + for _, c := range diff { + _, ok := SupportedCapabilities[c.Name] + if !c.Enable || !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c.Name) + } + + _, ok := s.availableCaps["sasl"] + if s.acct == "" && ok { + // TODO authenticate + } + + err = s.send(req.String()) + if err != nil { + return + } + case "DEL": + diff := TokenizeCaps(msg.Params[2]) + + for i := range diff { + diff[i].Enable = !diff[i].Enable + } + + for _, c := range diff { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + var req strings.Builder + + for _, c := range diff { + _, ok := SupportedCapabilities[c.Name] + if !c.Enable || !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c.Name) + } + + _, ok := s.availableCaps["sasl"] + if s.acct == "" && ok { + // TODO authenticate + } + + err = s.send(req.String()) + if err != nil { + return + } + } + case "JOIN": + nick, _, _ := FullMask(msg.Prefix) + lNick := strings.ToLower(nick) + channel := strings.ToLower(msg.Params[0]) + channelEv := ChannelEvent{Channel: msg.Params[0]} + + if lNick == s.lNick { + s.channels[channel] = Channel{ + Name: msg.Params[0], + Members: map[string]string{}, + } + } else if c, ok := s.channels[channel]; ok { + if _, ok := s.users[lNick]; !ok { + s.users[lNick] = User{Nick: nick} + } + c.Members[lNick] = "" + ev = UserJoinEvent{ChannelEvent: channelEv, UserEvent: UserEvent{Nick: nick}} + } + case "353": // RPL_NAMREPLY + channel := strings.ToLower(msg.Params[2]) + + if c, ok := s.channels[channel]; ok { + c.Secret = msg.Params[1] == "@" + names := TokenizeNames(msg.Params[3], "~&@%+") // TODO compute prefixes + + for _, name := range names { + nick := name.Nick + lNick := strings.ToLower(nick) + + if _, ok := s.users[lNick]; !ok { + s.users[lNick] = User{Nick: nick} + } + c.Members[lNick] = name.PowerLevel + } + } + case "366": // RPL_ENDOFNAMES + ev = SelfJoinEvent{ChannelEvent{Channel: msg.Params[1]}} + case "332": // RPL_TOPIC + channel := strings.ToLower(msg.Params[1]) + + if c, ok := s.channels[channel]; ok { + c.Topic = msg.Params[2] + } + case "PRIVMSG": + nick, _, _ := FullMask(msg.Prefix) + target := strings.ToLower(msg.Params[0]) + + if target == s.lNick { + // PRIVMSG to self + t, ok := msg.Time() + if !ok { + t = time.Now() + } + ev = QueryMessageEvent{ + UserEvent: UserEvent{Nick: nick}, + Content: msg.Params[1], + Time: t, + } + } else if _, ok := s.channels[target]; ok { + // PRIVMSG to channel + t, ok := msg.Time() + if !ok { + t = time.Now() + } + ev = ChannelMessageEvent{ + UserEvent: UserEvent{Nick: nick}, + ChannelEvent: ChannelEvent{Channel: msg.Params[0]}, + Content: msg.Params[1], + Time: t, + } + } + case "PING": + err = s.send("PONG :%s\r\n", msg.Params[0]) + if err != nil { + return + } + case "ERROR": + err = errors.New("connection terminated") + if len(msg.Params) > 0 { + err = fmt.Errorf("connection terminated: %s", msg.Params[0]) + } + s.state = ConnQuit + default: + } + return +} + +func (s *Session) updateFeatures(features []string) { + for _, f := range features { + if f == "" || f == "-" || f == "=" || f == "-=" { + continue + } + + var ( + add bool + key string + value string + ) + + if strings.HasPrefix(f, "-") { + add = false + f = f[1:] + } else { + add = true + } + + kv := strings.SplitN(f, "=", 2) + key = strings.ToUpper(kv[0]) + if len(kv) > 1 { + value = kv[1] + } + + if add { + s.features[key] = value + } else { + delete(s.features, key) + } + } +} + +/* +func (cli *Session) send(format string, args ...interface{}) (err error) { + msg := fmt.Sprintf(format, args...) + + for _, line := range strings.Split(msg, "\r\n") { + if line != "" { + fmt.Println("< ", line) + } + } + + _, err = cli.conn.Write([]byte(msg)) + + return +} + +// */ + +//* +func (s *Session) send(format string, args ...interface{}) (err error) { + _, err = fmt.Fprintf(s.conn, format, args...) + return +} + +// */ |