482 lines
12 KiB
Go
482 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"matrix-bot/bot"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/crypto/cryptohelper"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Runtime configuration, populated once by loadConfig from MATRIX_* environment variables.
var (
	homeserver      string // MATRIX_HOMESERVER
	username        string // MATRIX_USERNAME
	password        string // MATRIX_PASSWORD
	roomID          string // MATRIX_ROOM_ID (optional, defaults to "")
	pickleKeyString string // MATRIX_PICKLE_KEY: pickle key handed to the crypto helper
	recoveryKey     string // MATRIX_RECOVERY_KEY: used to unlock cross-signing keys from SSSS
	cryptoDBPath    string // MATRIX_CRYPTO_DB (defaults to "crypto.db")
	credentialsPath string // MATRIX_CREDENTIALS_PATH (defaults to "/data/credentials.json")

	// cryptoResetOnMismatch is set when MATRIX_CRYPTO_RESET_ON_MISMATCH is "1" or
	// "true" (any case). When set, local crypto state and cached credentials are
	// deleted on an olm account mismatch before the process exits.
	cryptoResetOnMismatch bool
)
|
|
|
|
// storedCredentials is the JSON document cached on disk between runs so the bot
// can resume its Matrix session without a password login.
type storedCredentials struct {
	UserID      string `json:"user_id"`      // Matrix user ID of the logged-in account
	AccessToken string `json:"access_token"` // access token for that session
	DeviceID    string `json:"device_id"`    // device the access token was issued for
}
|
|
|
|
type authManager struct {
|
|
mu sync.Mutex
|
|
client *mautrix.Client
|
|
credentialsPath string
|
|
|
|
username string
|
|
password string
|
|
|
|
afterLogin func(ctx context.Context) error
|
|
}
|
|
|
|
// isInvalidToken reports whether err means the homeserver rejected the access
// token: an M_UNKNOWN_TOKEN / M_MISSING_TOKEN errcode, or any HTTP 401 response.
func isInvalidToken(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, mautrix.MUnknownToken), errors.Is(err, mautrix.MMissingToken):
		return true
	}

	var httpErr mautrix.HTTPError
	if !errors.As(err, &httpErr) {
		return false
	}
	// Some servers return 401 without a Matrix errcode.
	return httpErr.IsStatus(401)
}
|
|
|
|
// relogin performs a password login, asking the server to reuse the client's
// current device ID, and updates a.client in place with the new session.
//
// After a successful login it persists the new credentials and, if set, runs
// a.afterLogin. An olm account mismatch reported by afterLogin terminates the
// process (optionally wiping local crypto state first). Other afterLogin
// errors are only logged.
//
// It returns an error only when the password credentials are missing or the
// login request itself fails.
func (a *authManager) relogin(ctx context.Context) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.username == "" || a.password == "" {
		return errors.New("missing MATRIX_USERNAME or MATRIX_PASSWORD for re-login")
	}

	preferredDeviceID := a.client.DeviceID
	log.Printf("Re-logging in to Matrix (preferred device_id=%q)", preferredDeviceID)
	_, err := a.client.Login(ctx, &mautrix.ReqLogin{
		Type: mautrix.AuthTypePassword,
		Identifier: mautrix.UserIdentifier{
			Type: mautrix.IdentifierTypeUser,
			User: a.username,
		},
		Password:         a.password,
		DeviceID:         preferredDeviceID,
		StoreCredentials: true,
	})
	if err != nil {
		return fmt.Errorf("password login: %w", err)
	}

	// StoreCredentials has already put the new token on the client, so the
	// session is usable even if persisting it fails. Returning an error here
	// would skip afterLogin and make the caller report a successful login as a
	// failure. The only cost of not persisting is another re-login on the
	// next restart.
	if err := saveStoredCredentials(a.credentialsPath, &storedCredentials{
		UserID:      a.client.UserID.String(),
		AccessToken: a.client.AccessToken,
		DeviceID:    a.client.DeviceID.String(),
	}); err != nil {
		log.Println("Failed to persist refreshed credentials:", err)
	} else {
		log.Println("Updated credentials in", a.credentialsPath)
	}

	if a.afterLogin != nil {
		if err := a.afterLogin(ctx); err != nil {
			if isOlmAccountMismatch(err) {
				log.Println("Detected olm account mismatch with server keys")
				if cryptoResetOnMismatch {
					if resetErr := resetCryptoState(); resetErr != nil {
						log.Fatal(resetErr)
					}
					log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
				}
				log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
			}
			log.Println("Post-login recovery key verification failed:", err)
		}
	}

	return nil
}
|
|
|
|
func runSyncWithAutoRelogin(ctx context.Context, client *mautrix.Client, auth *authManager) {
|
|
backoff := 2 * time.Second
|
|
maxBackoff := 30 * time.Second
|
|
for {
|
|
err := client.SyncWithContext(ctx)
|
|
if err == nil {
|
|
return
|
|
}
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
if isInvalidToken(err) {
|
|
log.Println("Matrix token invalid/expired; attempting re-login:", err)
|
|
if err2 := auth.relogin(ctx); err2 != nil {
|
|
log.Println("Re-login failed:", err2)
|
|
time.Sleep(backoff)
|
|
backoff *= 2
|
|
if backoff > maxBackoff {
|
|
backoff = maxBackoff
|
|
}
|
|
continue
|
|
}
|
|
backoff = 2 * time.Second
|
|
continue
|
|
}
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// setupCryptoHelper creates the end-to-end encryption helper for cli and
// initializes it. The helper is backed by the store at cryptoDBPath and given
// pickleKeyString as its pickle key. The store's parent directory is created
// first when the path has one.
func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) {
	if cryptoDBPath != "" {
		if parent := filepath.Dir(cryptoDBPath); parent != "." {
			if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
				return nil, mkErr
			}
		}
	}

	helper, err := cryptohelper.NewCryptoHelper(cli, []byte(pickleKeyString), cryptoDBPath)
	if err != nil {
		return nil, err
	}
	if initErr := helper.Init(context.Background()); initErr != nil {
		return nil, initErr
	}
	return helper, nil
}
|
|
|
|
// verifyWithRecoveryKey uses the configured recovery key to mark this device
// as cross-signed. The steps are:
//  1. unlock the default SSSS key with recoveryKey;
//  2. fetch the cross-signing keys from SSSS;
//  3. sign the bot's own device and its own master key.
//
// It stops at the first failing step and returns that error wrapped with the
// step name. The wrapping uses %w, so the original message, including any
// "olm account is not marked as shared" text, is preserved.
func verifyWithRecoveryKey(machine *crypto.OlmMachine) error {
	ctx := context.Background()

	keyID, keyData, err := machine.SSSS.GetDefaultKeyData(ctx)
	if err != nil {
		return fmt.Errorf("fetching default SSSS key data: %w", err)
	}
	key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
	if err != nil {
		return fmt.Errorf("verifying recovery key: %w", err)
	}
	if err := machine.FetchCrossSigningKeysFromSSSS(ctx, key); err != nil {
		return fmt.Errorf("fetching cross-signing keys from SSSS: %w", err)
	}
	if err := machine.SignOwnDevice(ctx, machine.OwnIdentity()); err != nil {
		return fmt.Errorf("signing own device: %w", err)
	}
	if err := machine.SignOwnMasterKey(ctx); err != nil {
		return fmt.Errorf("signing own master key: %w", err)
	}
	return nil
}
|
|
|
|
// envOrFatal returns the value of the environment variable key. If the
// variable is unset or empty, it logs a fatal error and exits the process.
func envOrFatal(key string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	log.Fatalf("missing required env var: %s", key)
	return "" // unreachable: log.Fatalf exits the process
}
|
|
|
|
// envOrDefault returns the value of the environment variable key. It returns
// def when the variable is unset or empty.
func envOrDefault(key string, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}
|
|
|
|
// loadStoredCredentials reads the cached Matrix session from the JSON file at path.
//
// It returns (nil, nil), meaning "no cached session, log in with the password",
// in three cases:
//   - path is empty;
//   - the file does not exist;
//   - the file is empty or whitespace-only, e.g. left behind by an interrupted write.
//
// Read and decode failures are returned wrapped with the file path.
func loadStoredCredentials(path string) (*storedCredentials, error) {
	if path == "" {
		return nil, nil
	}
	data, err := os.ReadFile(path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil, nil
		}
		return nil, fmt.Errorf("reading credentials %q: %w", path, err)
	}
	if strings.TrimSpace(string(data)) == "" {
		return nil, nil
	}
	var creds storedCredentials
	if err := json.Unmarshal(data, &creds); err != nil {
		return nil, fmt.Errorf("parsing credentials %q: %w", path, err)
	}
	return &creds, nil
}
|
|
|
|
// saveStoredCredentials writes creds as indented JSON to path, creating the
// parent directory if needed. An empty path disables persistence and returns nil.
//
// The write is atomic: the data goes to a temporary file in the same directory,
// which is fsynced and then renamed over path. A crash therefore never leaves a
// truncated credentials file. The result always has mode 0600, because
// os.CreateTemp creates files that way, even if an older file at path had
// looser permissions.
func saveStoredCredentials(path string, creds *storedCredentials) error {
	if path == "" {
		return nil
	}
	dir := filepath.Dir(path)
	if dir != "." {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return fmt.Errorf("creating credentials directory %q: %w", dir, err)
		}
	}
	data, err := json.MarshalIndent(creds, "", " ")
	if err != nil {
		return fmt.Errorf("encoding credentials: %w", err)
	}

	tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return fmt.Errorf("creating temp credentials file: %w", err)
	}
	tmpName := tmp.Name()
	// Clean up the temp file on every failure path. After a successful rename,
	// tmpName is cleared, so nothing is removed.
	defer func() {
		if tmpName != "" {
			_ = os.Remove(tmpName) // best-effort cleanup; the original error is what matters
		}
	}()

	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("writing credentials: %w", err)
	}
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("syncing credentials: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("closing credentials: %w", err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return fmt.Errorf("replacing credentials %q: %w", path, err)
	}
	tmpName = ""
	return nil
}
|
|
|
|
// readDeviceIDFromCryptoDB returns the device_id recorded in an existing crypto
// database at path. It reads the crypto_account row whose account_id is "".
//
// It returns "" with a nil error in these cases:
//   - path is empty;
//   - the file does not exist;
//   - there is no such row;
//   - the crypto_account table has not been created yet.
//
// Any other stat, open, or query error is returned unchanged.
func readDeviceIDFromCryptoDB(path string) (id.DeviceID, error) {
	if path == "" {
		return "", nil
	}

	_, statErr := os.Stat(path)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return "", nil
	case statErr != nil:
		return "", statErr
	}

	// Match mautrix's default DSN style.
	db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_txlock=immediate", path))
	if err != nil {
		return "", err
	}
	defer db.Close()

	// CryptoHelper defaults DBAccountID to "", so account_id is "" unless explicitly set.
	var deviceID id.DeviceID
	row := db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = '' LIMIT 1")
	switch scanErr := row.Scan(&deviceID); {
	case scanErr == nil:
		return deviceID, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", nil
	case strings.Contains(scanErr.Error(), "no such table"):
		// Fresh/invalid DB.
		return "", nil
	default:
		return "", scanErr
	}
}
|
|
|
|
// loadConfig fills the package-level configuration from MATRIX_* environment
// variables. Required variables are read with envOrFatal, so the process exits
// on the first one that is missing. Optional ones fall back to their defaults.
func loadConfig() {
	homeserver = envOrFatal("MATRIX_HOMESERVER")
	username = envOrFatal("MATRIX_USERNAME")
	password = envOrFatal("MATRIX_PASSWORD")
	roomID = envOrDefault("MATRIX_ROOM_ID", "")
	pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY")
	recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY")
	cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db")
	credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json")

	// Accept "1" or "true" in any letter case; anything else leaves the flag off.
	resetFlag := envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", "")
	cryptoResetOnMismatch = resetFlag == "1" || strings.EqualFold(resetFlag, "true")
}
|
|
|
|
// isOlmAccountMismatch reports whether err's message contains
// "olm account is not marked as shared", which is the signal this file uses to
// detect that local crypto state disagrees with the keys on the server. The
// check is a substring match, so it still works on wrapped errors.
func isOlmAccountMismatch(err error) bool {
	return err != nil && strings.Contains(err.Error(), "olm account is not marked as shared")
}
|
|
|
|
// resetCryptoState deletes local encryption state and cached credentials so
// the next start performs a fresh password login with a new olm account.
//
// It removes the crypto database at cryptoDBPath together with any SQLite
// sidecar files next to it (-wal, -shm, -journal). Deleting only the main file
// could leave a stale log beside a newly created database. It then removes the
// credentials file at credentialsPath. Empty paths are skipped and missing
// files are not errors. Every removal is attempted, and all failures are
// returned together via errors.Join (nil when everything succeeded).
func resetCryptoState() error {
	var paths []string
	if cryptoDBPath != "" {
		paths = append(paths,
			cryptoDBPath,
			cryptoDBPath+"-wal",
			cryptoDBPath+"-shm",
			cryptoDBPath+"-journal",
		)
	}
	if credentialsPath != "" {
		paths = append(paths, credentialsPath)
	}

	var errs []error
	for _, p := range paths {
		// os.Remove errors are *PathError, so each one already names its file.
		if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) {
			errs = append(errs, err)
		}
	}
	return errors.Join(errs...)
}
|
|
|
|
// main runs the bot:
//  1. load configuration and the bot's command handlers;
//  2. restore the cached Matrix session, or log in with the password;
//  3. set up end-to-end encryption;
//  4. register the text-message handler;
//  5. start syncing with automatic re-login;
//  6. after the first sync, verify this device with the recovery key, then block forever.
func main() {
	loadConfig()
	bot.Load()

	cryptoDeviceID, err := readDeviceIDFromCryptoDB(cryptoDBPath)
	if err != nil {
		log.Fatal(err)
	}
	forceLogin := false

	stored, err := loadStoredCredentials(credentialsPath)
	if err != nil {
		log.Fatal(err)
	}
	var cachedUserID, cachedAccessToken, cachedDeviceID string
	if stored != nil {
		cachedUserID = stored.UserID
		cachedAccessToken = stored.AccessToken
		cachedDeviceID = stored.DeviceID
		if cachedUserID != "" || cachedAccessToken != "" || cachedDeviceID != "" {
			log.Println("Loaded credentials from", credentialsPath)
		}
	}

	client, err := mautrix.NewClient(homeserver, id.UserID(cachedUserID), cachedAccessToken)
	if err != nil {
		log.Fatal(err)
	}

	// Device ID source of truth:
	// 1. crypto DB (if present)
	// 2. cached credentials
	if cryptoDeviceID != "" {
		client.DeviceID = cryptoDeviceID
		if cachedDeviceID != "" && cachedDeviceID != cryptoDeviceID.String() {
			log.Printf("Device ID mismatch between credentials and crypto DB (%q != %q). Will re-login using crypto DB device ID.", cachedDeviceID, cryptoDeviceID)
			forceLogin = true
		}
	} else if cachedDeviceID != "" {
		client.DeviceID = id.DeviceID(cachedDeviceID)
	}

	if client.AccessToken == "" || client.UserID == "" || client.DeviceID == "" || forceLogin {
		log.Println("Logging in to Matrix")
		_, err = client.Login(context.Background(), &mautrix.ReqLogin{
			Type: mautrix.AuthTypePassword,
			Identifier: mautrix.UserIdentifier{
				Type: mautrix.IdentifierTypeUser,
				User: username,
			},
			Password:         password,
			DeviceID:         client.DeviceID,
			StoreCredentials: true,
		})
		if err != nil {
			log.Fatal(err)
		}
		if err := saveStoredCredentials(credentialsPath, &storedCredentials{
			UserID:      client.UserID.String(),
			AccessToken: client.AccessToken,
			DeviceID:    client.DeviceID.String(),
		}); err != nil {
			log.Fatal(err)
		}
		log.Println("Saved credentials to", credentialsPath)
	}

	auth := &authManager{
		client:          client,
		credentialsPath: credentialsPath,
		username:        username,
		password:        password,
	}

	syncer := mautrix.NewDefaultSyncer()
	client.Syncer = syncer

	cryptoHelper, err := setupCryptoHelper(client)
	if err != nil {
		exitOnCryptoError(err)
	}
	client.Crypto = cryptoHelper
	auth.afterLogin = func(ctx context.Context) error {
		return verifyWithRecoveryKey(cryptoHelper.Machine())
	}

	syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) {
		// Ignore our own messages
		if client.UserID != "" && evt.Sender == client.UserID {
			return
		}

		content := evt.Content.AsMessage()
		if content.MsgType != event.MsgText {
			return
		}

		log.Printf("Message from %s: %s\n", evt.Sender, content.Body)

		response := bot.HandleCommand(content.Body, evt.Sender.String(), evt.RoomID.String(),
			ctx, client, evt, &event.RelatesTo{
				InReplyTo: &event.InReplyTo{EventID: evt.ID},
			})
		if response == nil {
			return
		}

		for _, resp := range response {
			switch r := resp.(type) {
			case event.MessageEventContent:
				_, err := client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, r)
				if err != nil {
					log.Println("Send error:", err)
				}
			default:
				log.Println("Unknown response type")
			}
		}
	})

	// ready is closed after the first completed sync. Cross-signing
	// verification below waits for it.
	ready := make(chan struct{})
	var once sync.Once
	syncer.OnSync(func(ctx context.Context, resp *mautrix.RespSync, since string) bool {
		once.Do(func() { close(ready) })
		return true
	})

	go func() {
		runSyncWithAutoRelogin(context.Background(), client, auth)
	}()

	log.Println("Waiting for initial sync...")
	<-ready
	log.Println("Sync complete")

	if err := verifyWithRecoveryKey(cryptoHelper.Machine()); err != nil {
		exitOnCryptoError(err)
	}

	log.Println("Bot is running...")
	select {}
}

// exitOnCryptoError terminates the process after a fatal crypto setup or
// verification error; it never returns.
//
// On an olm account mismatch it logs the mismatch and, when
// cryptoResetOnMismatch is set, first deletes local crypto state and cached
// credentials so that a restart logs in fresh. Any other error is logged
// unchanged.
func exitOnCryptoError(err error) {
	if isOlmAccountMismatch(err) {
		log.Println("Detected olm account mismatch with server keys")
		if cryptoResetOnMismatch {
			if resetErr := resetCryptoState(); resetErr != nil {
				log.Fatal(resetErr)
			}
			log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
		}
		log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
	}
	log.Fatal(err)
}
|