diff options
Diffstat (limited to 'irc')
-rw-r--r-- | irc/events.go | 52 | ||||
-rw-r--r-- | irc/states.go | 625 | ||||
-rw-r--r-- | irc/tokens.go | 327 |
3 files changed, 1004 insertions, 0 deletions
diff --git a/irc/events.go b/irc/events.go new file mode 100644 index 0000000..2cd5214 --- /dev/null +++ b/irc/events.go @@ -0,0 +1,52 @@ +package irc + +import ( + "strings" + "time" +) + +type Event interface{} + +type RegisteredEvent struct{} + +type UserEvent struct { + Nick string + User string + Host string +} + +func (u UserEvent) NickMapped() (nick string) { + nick = strings.ToLower(u.Nick) + return +} + +type ChannelEvent struct { + Channel string +} + +func (c ChannelEvent) ChannelMapped() (channel string) { + channel = strings.ToLower(c.Channel) + return +} + +type UserJoinEvent struct { + UserEvent + ChannelEvent +} + +type SelfJoinEvent struct { + ChannelEvent +} + +type QueryMessageEvent struct { + UserEvent + Content string + Time time.Time +} + +type ChannelMessageEvent struct { + UserEvent + ChannelEvent + Content string + Time time.Time +} diff --git a/irc/states.go b/irc/states.go new file mode 100644 index 0000000..177c68f --- /dev/null +++ b/irc/states.go @@ -0,0 +1,625 @@ +package irc + +import ( + "bufio" + "bytes" + "encoding/base64" + "errors" + "fmt" + "io" + "strings" + "sync/atomic" + "time" +) + +type SASLClient interface { + Handshake() (mech string) + Respond(challenge string) (res string, err error) +} + +type SASLPlain struct { + Username string + Password string +} + +func (auth *SASLPlain) Handshake() (mech string) { + mech = "PLAIN" + return +} + +func (auth *SASLPlain) Respond(challenge string) (res string, err error) { + if challenge != "+" { + err = errors.New("Unexpected challenge") + return + } + + user := []byte(auth.Username) + pass := []byte(auth.Password) + payload := bytes.Join([][]byte{user, user, pass}, []byte{0}) + res = base64.StdEncoding.EncodeToString(payload) + + return +} + +var SupportedCapabilities = map[string]struct{}{ + "account-notify": {}, + "account-tag": {}, + "away-notify": {}, + "batch": {}, + "cap-notify": {}, + "echo-message": {}, + "extended-join": {}, + "invite-notify": {}, + "labeled-response": {}, + "message-tags": {}, + "multi-prefix": {}, + "server-time": {}, + "sasl": {}, + "setname": {}, + "userhost-in-names": {}, +} + +type ConnectionState int + +const ( + ConnStart ConnectionState = iota + ConnRegistered + ConnQuit +) + +type User struct { + Nick string + AwayMsg string +} + +type Channel struct { + Name string + Members map[string]string + Topic string + TopicWho string + TopicTime time.Time + Secret bool +} + +type action interface{} + +type ( + actionPrivMsg struct { + To string + Content string + } + + actionTyping struct { + To string + } + actionTypingStop struct { + To string + } +) + +type SessionParams struct { + Nickname string + Username string + RealName string + + Auth SASLClient +} + +type Session struct { + conn io.ReadWriteCloser + msgs chan Message + acts chan action + evts chan Event + + running atomic.Value // bool + state ConnectionState + typingStamps map[string]time.Time + + nick string + lNick string + user string + real string + acct string + host string + auth SASLClient + + mode string + motd string + + availableCaps map[string]string + enabledCaps map[string]struct{} + features map[string]string + + users map[string]User + channels map[string]Channel +} + +func NewSession(conn io.ReadWriteCloser, params SessionParams) (s Session, err error) { + s = Session{ + conn: conn, + msgs: make(chan Message, 10), + acts: make(chan action, 10), + evts: make(chan Event, 10), + typingStamps: map[string]time.Time{}, + nick: params.Nickname, + lNick: strings.ToLower(params.Nickname), + user: params.Username, + real: params.RealName, + auth: params.Auth, + availableCaps: map[string]string{}, + enabledCaps: map[string]struct{}{}, + features: map[string]string{}, + users: map[string]User{}, + channels: map[string]Channel{}, + } + + s.running.Store(true) + + err = s.send("CAP LS 302\r\nNICK %s\r\nUSER %s 0 * :%s\r\n", s.nick, s.user, s.real) + if err != nil { + return + } + + go func() { + r := bufio.NewScanner(conn) + + for r.Scan() { + line := r.Text() + //fmt.Println(" > ", line) + + msg, err := Tokenize(line) + if err != nil { + continue + } + + err = msg.Validate() + if err != nil { + continue + } + + s.msgs <- msg + } + + s.Stop() + }() + + go s.run() + + return +} + +func (s *Session) Running() bool { + return s.running.Load().(bool) +} + +func (s *Session) Stop() { + s.running.Store(false) + s.conn.Close() +} + +func (s *Session) Poll() (events <-chan Event) { + return s.evts +} + +func (s *Session) IsChannel(name string) bool { + return strings.IndexAny(name, "#&") == 0 // TODO compute CHANTYPES +} + +func (s *Session) PrivMsg(to, content string) { + s.acts <- actionPrivMsg{to, content} +} + +func (s *Session) privMsg(act actionPrivMsg) (err error) { + err = s.send("PRIVMSG %s :%s\r\n", act.To, act.Content) + return +} + +func (s *Session) Typing(to string) { + s.acts <- actionTyping{to} +} + +func (s *Session) typing(act actionTyping) (err error) { + if _, ok := s.enabledCaps["message-tags"]; !ok { + return + } + + to := strings.ToLower(act.To) + + if t, ok := s.typingStamps[to]; ok && time.Now().Sub(t) < 3 { + return + } + + err = s.send("@+typing=active TAGMSG %s\r\n", act.To) + return +} + +func (s *Session) TypingStop(to string) { + s.acts <- actionTypingStop{to} +} + +func (s *Session) typingStop(act actionTypingStop) (err error) { + if _, ok := s.enabledCaps["message-tags"]; !ok { + return + } + + err = s.send("@+typing=done TAGMSG %s\r\n", act.To) + return +} + +func (s *Session) run() { + for s.Running() { + var ( + ev Event + err error + ) + + select { + case act := <-s.acts: + switch act := act.(type) { + case actionPrivMsg: + err = s.privMsg(act) + case actionTyping: + err = s.typing(act) + case actionTypingStop: + err = s.typingStop(act) + } + case msg := <-s.msgs: + if s.state == ConnStart { + ev, err = s.handleStart(msg) + } else if s.state == ConnRegistered { + ev, err = s.handle(msg) + } + } + + if ev != nil { + s.evts <- ev + } + if err != nil { + s.evts <- err + } + } +} + +func (s *Session) handleStart(msg Message) (ev Event, err error) { + switch msg.Command { + case "AUTHENTICATE": + if s.auth != nil { + var res string + + res, err = s.auth.Respond(msg.Params[0]) + if err != nil { + err = s.send("AUTHENTICATE *\r\n") + return + } + + err = s.send("AUTHENTICATE %s\r\n", res) + if err != nil { + return + } + } + case "900": + err = s.send("CAP END\r\n") + if err != nil { + return + } + + s.acct = msg.Params[2] + _, _, s.host = FullMask(msg.Params[1]) + case "902", "904", "905", "906", "907", "908": + err = s.send("CAP END\r\n") + if err != nil { + return + } + case "CAP": + switch msg.Params[1] { + case "LS": + var willContinue bool + var ls string + + if msg.Params[2] == "*" { + willContinue = true + ls = msg.Params[3] + } else { + willContinue = false + ls = msg.Params[2] + } + + for _, c := range TokenizeCaps(ls) { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + if !willContinue { + var req strings.Builder + + for c := range s.availableCaps { + if _, ok := SupportedCapabilities[c]; !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c) + } + + _, ok := s.availableCaps["sasl"] + if s.auth == nil || !ok { + _, _ = fmt.Fprintf(&req, "CAP END\r\n") + } + + err = s.send(req.String()) + if err != nil { + return + } + } + case "ACK": + for _, c := range strings.Split(msg.Params[2], " ") { + s.enabledCaps[c] = struct{}{} + + if s.auth != nil && c == "sasl" { + h := s.auth.Handshake() + err = s.send("AUTHENTICATE %s\r\n", h) + if err != nil { + return + } + } + } + } + case "372": // RPL_MOTD + s.motd += "\n" + strings.TrimPrefix(msg.Params[1], "- ") + case "433": // ERR_NICKNAMEINUSE + s.nick = s.nick + "_" + + err = s.send("NICK %s\r\n", s.nick) + if err != nil { + return + } + default: + ev, err = s.handle(msg) + } + + return +} + +func (s *Session) handle(msg Message) (ev Event, err error) { + switch msg.Command { + case "001": // RPL_WELCOME + s.nick = msg.Params[0] + s.lNick = strings.ToLower(s.nick) + s.state = ConnRegistered + ev = RegisteredEvent{} + + if s.host == "" { + err = s.send("WHO %s\r\n", s.nick) + if err != nil { + return + } + } + case "005": // RPL_ISUPPORT + s.updateFeatures(msg.Params[1 : len(msg.Params)-1]) + case "352": // RPL_WHOREPLY + if s.lNick == strings.ToLower(msg.Params[5]) { + s.host = msg.Params[3] + } + case "CAP": + switch msg.Params[1] { + case "ACK": + for _, c := range strings.Split(msg.Params[2], " ") { + s.enabledCaps[c] = struct{}{} + } + case "NAK": + for _, c := range strings.Split(msg.Params[2], " ") { + delete(s.enabledCaps, c) + } + case "NEW": + diff := TokenizeCaps(msg.Params[2]) + + for _, c := range diff { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + var req strings.Builder + + for _, c := range diff { + _, ok := SupportedCapabilities[c.Name] + if !c.Enable || !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c.Name) + } + + _, ok := s.availableCaps["sasl"] + if s.acct == "" && ok { + // TODO authenticate + } + + err = s.send(req.String()) + if err != nil { + return + } + case "DEL": + diff := TokenizeCaps(msg.Params[2]) + + for i := range diff { + diff[i].Enable = !diff[i].Enable + } + + for _, c := range diff { + if c.Enable { + s.availableCaps[c.Name] = c.Value + } else { + delete(s.availableCaps, c.Name) + } + } + + var req strings.Builder + + for _, c := range diff { + _, ok := SupportedCapabilities[c.Name] + if !c.Enable || !ok { + continue + } + + _, _ = fmt.Fprintf(&req, "CAP REQ %s\r\n", c.Name) + } + + _, ok := s.availableCaps["sasl"] + if s.acct == "" && ok { + // TODO authenticate + } + + err = s.send(req.String()) + if err != nil { + return + } + } + case "JOIN": + nick, _, _ := FullMask(msg.Prefix) + lNick := strings.ToLower(nick) + channel := strings.ToLower(msg.Params[0]) + channelEv := ChannelEvent{Channel: msg.Params[0]} + + if lNick == s.lNick { + s.channels[channel] = Channel{ + Name: msg.Params[0], + Members: map[string]string{}, + } + } else if c, ok := s.channels[channel]; ok { + if _, ok := s.users[lNick]; !ok { + s.users[lNick] = User{Nick: nick} + } + c.Members[lNick] = "" + ev = UserJoinEvent{ChannelEvent: channelEv, UserEvent: UserEvent{Nick: nick}} + } + case "353": // RPL_NAMREPLY + channel := strings.ToLower(msg.Params[2]) + + if c, ok := s.channels[channel]; ok { + c.Secret = msg.Params[1] == "@" + names := TokenizeNames(msg.Params[3], "~&@%+") // TODO compute prefixes + + for _, name := range names { + nick := name.Nick + lNick := strings.ToLower(nick) + + if _, ok := s.users[lNick]; !ok { + s.users[lNick] = User{Nick: nick} + } + c.Members[lNick] = name.PowerLevel + } + } + case "366": // RPL_ENDOFNAMES + ev = SelfJoinEvent{ChannelEvent{Channel: msg.Params[1]}} + case "332": // RPL_TOPIC + channel := strings.ToLower(msg.Params[1]) + + if c, ok := s.channels[channel]; ok { + c.Topic = msg.Params[2] + } + case "PRIVMSG": + nick, _, _ := FullMask(msg.Prefix) + target := strings.ToLower(msg.Params[0]) + + if target == s.lNick { + // PRIVMSG to self + t, ok := msg.Time() + if !ok { + t = time.Now() + } + ev = QueryMessageEvent{ + UserEvent: UserEvent{Nick: nick}, + Content: msg.Params[1], + Time: t, + } + } else if _, ok := s.channels[target]; ok { + // PRIVMSG to channel + t, ok := msg.Time() + if !ok { + t = time.Now() + } + ev = ChannelMessageEvent{ + UserEvent: UserEvent{Nick: nick}, + ChannelEvent: ChannelEvent{Channel: msg.Params[0]}, + Content: msg.Params[1], + Time: t, + } + } + case "PING": + err = s.send("PONG :%s\r\n", msg.Params[0]) + if err != nil { + return + } + case "ERROR": + err = errors.New("connection terminated") + if len(msg.Params) > 0 { + err = fmt.Errorf("connection terminated: %s", msg.Params[0]) + } + s.state = ConnQuit + default: + } + return +} + +func (s *Session) updateFeatures(features []string) { + for _, f := range features { + if f == "" || f == "-" || f == "=" || f == "-=" { + continue + } + + var ( + add bool + key string + value string + ) + + if strings.HasPrefix(f, "-") { + add = false + f = f[1:] + } else { + add = true + } + + kv := strings.SplitN(f, "=", 2) + key = strings.ToUpper(kv[0]) + if len(kv) > 1 { + value = kv[1] + } + + if add { + s.features[key] = value + } else { + delete(s.features, key) + } + } +} + +/* +func (cli *Session) send(format string, args ...interface{}) (err error) { + msg := fmt.Sprintf(format, args...) + + for _, line := range strings.Split(msg, "\r\n") { + if line != "" { + fmt.Println("< ", line) + } + } + + _, err = cli.conn.Write([]byte(msg)) + + return +} + +// */ + +//* +func (s *Session) send(format string, args ...interface{}) (err error) { + _, err = fmt.Fprintf(s.conn, format, args...) + return +} + +// */ diff --git a/irc/tokens.go b/irc/tokens.go new file mode 100644 index 0000000..976207d --- /dev/null +++ b/irc/tokens.go @@ -0,0 +1,327 @@ +package irc + +import ( + "errors" + "fmt" + "strings" + "time" +) + +func word(s string) (w, rest string) { + split := strings.SplitN(s, " ", 2) + + if len(split) < 2 { + w = split[0] + rest = "" + } else { + w = split[0] + rest = split[1] + } + + return +} + +func tagEscape(c rune) (escape rune) { + switch c { + case ':': + escape = ';' + case 's': + escape = ' ' + case 'r': + escape = '\r' + case 'n': + escape = '\n' + default: + escape = c + } + + return +} + +func unescapeTagValue(escaped string) (unescaped string) { + var builder strings.Builder + builder.Grow(len(escaped)) + escape := false + + for _, c := range escaped { + if c == '\\' && !escape { + escape = true + } else { + var cpp rune + + if escape { + cpp = tagEscape(c) + } else { + cpp = c + } + + builder.WriteRune(cpp) + escape = false + } + } + + unescaped = builder.String() + return +} + +func parseTags(s string) (tags map[string]string) { + s = s[1:] + tags = map[string]string{} + + for _, item := range strings.Split(s, ";") { + if item == "" || item == "=" || item == "+" || item == "+=" { + continue + } + + kv := strings.SplitN(item, "=", 2) + if len(kv) < 2 { + tags[kv[0]] = "" + } else { + tags[kv[0]] = unescapeTagValue(kv[1]) + } + } + + return +} + +var ( + errEmptyMessage = errors.New("empty message") + errIncompleteMessage = errors.New("message is incomplete") +) + +var ( + errNoPrefix = errors.New("missing prefix") + errNotEnoughParams = errors.New("not enough params") + errUnknownCommand = errors.New("unknown command") +) + +type Message struct { + Tags map[string]string + Prefix string + Command string + Params []string +} + +func Tokenize(line string) (msg Message, err error) { + line = strings.TrimLeft(line, " ") + if line == "" { + err = errEmptyMessage + return + } + + if line[0] == '@' { + var tags string + + tags, line = word(line) + msg.Tags = parseTags(tags) + } + + line = strings.TrimLeft(line, " ") + if line == "" { + err = errIncompleteMessage + return + } + + if line[0] == ':' { + var prefix string + + prefix, line = word(line) + msg.Prefix = prefix[1:] + } + + line = strings.TrimLeft(line, " ") + if line == "" { + err = errIncompleteMessage + return + } + + msg.Command, line = word(line) + msg.Command = strings.ToUpper(msg.Command) + + msg.Params = make([]string, 0, 15) + for line != "" { + if line[0] == ':' { + msg.Params = append(msg.Params, line[1:]) + break + } + + var param string + param, line = word(line) + msg.Params = append(msg.Params, param) + } + + return +} + +func (msg *Message) Validate() (err error) { + switch msg.Command { + case "001": + if len(msg.Params) < 1 { + err = errNotEnoughParams + } + case "005": + if len(msg.Params) < 3 { + err = errNotEnoughParams + } + case "352": + if len(msg.Params) < 8 { + err = errNotEnoughParams + } + case "372": + if len(msg.Params) < 2 { + err = errNotEnoughParams + } + case "AUTHENTICATE": + if len(msg.Params) < 1 { + err = errNotEnoughParams + } + case "900": + if len(msg.Params) < 3 { + err = errNotEnoughParams + } + case "901": + if len(msg.Params) < 2 { + err = errNotEnoughParams + } + case "CAP": + if len(msg.Params) < 3 { + err = errNotEnoughParams + } else if msg.Params[1] == "LS" { + } else if msg.Params[1] == "LIST" { + } else if msg.Params[1] == "ACK" { + } else if msg.Params[1] == "NAK" { + } else if msg.Params[1] == "NEW" { + } else if msg.Params[1] == "DEL" { + } else { + err = errUnknownCommand + } + case "JOIN": + if len(msg.Params) < 1 { + err = errNotEnoughParams + } else if msg.Prefix == "" { + err = errNoPrefix + } + case "353": + if len(msg.Params) < 4 { + err = errNotEnoughParams + } + case "332": + if len(msg.Params) < 3 { + err = errNotEnoughParams + } + case "TOPIC": + if len(msg.Params) < 2 { + err = errNotEnoughParams + } + case "PING": + if len(msg.Params) < 1 { + err = errNotEnoughParams + } + case "PONG": + if len(msg.Params) < 1 { + err = errNotEnoughParams + } + default: + } + return +} + +func (msg *Message) Time() (t time.Time, ok bool) { + var tag string + var year, month, day, hour, minute, second, millis int + + tag, ok = msg.Tags["time"] + if !ok { + return + } + + tag = strings.TrimSuffix(tag, "Z") + + _, err := fmt.Sscanf(tag, "%4d-%2d-%2dT%2d:%2d:%2d.%3d", &year, &month, &day, &hour, &minute, &second, &millis) + if err != nil || month < 1 || 12 < month { + ok = false + return + } + + t = time.Date(year, time.Month(month), day, hour, minute, second, millis*1000000, time.Local) + + return +} + +func FullMask(s string) (nick, user, host string) { + if s == "" { + return + } + + spl0 := strings.Split(s, "@") + if 1 < len(spl0) { + host = spl0[1] + } + + spl1 := strings.Split(spl0[0], "!") + if 1 < len(spl1) { + user = spl1[1] + } + + nick = spl1[0] + + return +} + +type Cap struct { + Name string + Value string + Enable bool +} + +func TokenizeCaps(caps string) (diff []Cap) { + for _, c := range strings.Split(caps, " ") { + if c == "" || c == "-" || c == "=" || c == "-=" { + continue + } + + var item Cap + + if strings.HasPrefix(c, "-") { + item.Enable = false + c = c[1:] + } else { + item.Enable = true + } + + kv := strings.SplitN(c, "=", 2) + item.Name = strings.ToLower(kv[0]) + if len(kv) > 1 { + item.Value = kv[1] + } + + diff = append(diff, item) + } + + return +} + +type Name struct { + PowerLevel string + Nick string + User string + Host string +} + +func TokenizeNames(trailing string, prefixes string) (names []Name) { + for _, name := range strings.Split(trailing, " ") { + if name == "" { + continue + } + + var item Name + + mask := strings.TrimLeft(name, prefixes) + item.Nick, item.User, item.Host = FullMask(mask) + item.PowerLevel = name[:len(name)-len(mask)] + + names = append(names, item) + } + + return +} |