package main import ( "context" "database/sql" "encoding/json" "errors" "fmt" "log" "matrix-bot/bot" "os" "path/filepath" "strings" "sync" "time" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" _ "github.com/mattn/go-sqlite3" ) var homeserver string var username string var password string var roomID string var pickleKeyString string var recoveryKey string var cryptoDBPath string var credentialsPath string var cryptoResetOnMismatch bool type storedCredentials struct { UserID string `json:"user_id"` AccessToken string `json:"access_token"` DeviceID string `json:"device_id"` } type authManager struct { mu sync.Mutex client *mautrix.Client credentialsPath string username string password string afterLogin func(ctx context.Context) error } func isInvalidToken(err error) bool { if err == nil { return false } if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MMissingToken) { return true } var httpErr mautrix.HTTPError if errors.As(err, &httpErr) { // Some servers return 401 without a Matrix errcode. return httpErr.IsStatus(401) } return false } func (a *authManager) relogin(ctx context.Context) error { a.mu.Lock() defer a.mu.Unlock() if a.username == "" || a.password == "" { return errors.New("missing MATRIX_USERNAME or MATRIX_PASSWORD for re-login") } preferredDeviceID := a.client.DeviceID log.Printf("Re-logging in to Matrix (preferred device_id=%q)", preferredDeviceID) _, err := a.client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: a.username, }, Password: a.password, DeviceID: preferredDeviceID, StoreCredentials: true, }) if err != nil { return err } if err := saveStoredCredentials(a.credentialsPath, &storedCredentials{ UserID: a.client.UserID.String(), AccessToken: a.client.AccessToken, DeviceID: a.client.DeviceID.String(), }); err != nil { return err } log.Println("Updated credentials in", a.credentialsPath) if a.afterLogin != nil { if err := a.afterLogin(ctx); err != nil { if isOlmAccountMismatch(err) { log.Println("Detected olm account mismatch with server keys") if cryptoResetOnMismatch { if resetErr := resetCryptoState(); resetErr != nil { log.Fatal(resetErr) } log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.") } log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.") } log.Println("Post-login recovery key verification failed:", err) } } return nil } func runSyncWithAutoRelogin(ctx context.Context, client *mautrix.Client, auth *authManager) { backoff := 2 * time.Second maxBackoff := 30 * time.Second for { err := client.SyncWithContext(ctx) if err == nil { return } if ctx.Err() != nil { return } if isInvalidToken(err) { log.Println("Matrix token invalid/expired; attempting re-login:", err) if err2 := auth.relogin(ctx); err2 != nil { log.Println("Re-login failed:", err2) time.Sleep(backoff) backoff *= 2 if backoff > maxBackoff { backoff = maxBackoff } continue } backoff = 2 * time.Second continue } log.Fatal(err) } } func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) { pickleKey := []byte(pickleKeyString) if cryptoDBPath != "" { dir := filepath.Dir(cryptoDBPath) if dir != "." { if err := os.MkdirAll(dir, 0o755); err != nil { return nil, err } } } helper, err := cryptohelper.NewCryptoHelper(cli, pickleKey, cryptoDBPath) if err != nil { return nil, err } err = helper.Init(context.Background()) if err != nil { return nil, err } return helper, nil } func verifyWithRecoveryKey(machine *crypto.OlmMachine) (err error) { ctx := context.Background() keyId, keyData, err := machine.SSSS.GetDefaultKeyData(ctx) if err != nil { return } key, err := keyData.VerifyRecoveryKey(keyId, recoveryKey) if err != nil { return } err = machine.FetchCrossSigningKeysFromSSSS(ctx, key) if err != nil { return } err = machine.SignOwnDevice(ctx, machine.OwnIdentity()) if err != nil { return } err = machine.SignOwnMasterKey(ctx) return } func envOrFatal(key string) string { val := os.Getenv(key) if val == "" { log.Fatalf("missing required env var: %s", key) } return val } func envOrDefault(key string, def string) string { val := os.Getenv(key) if val == "" { return def } return val } func loadStoredCredentials(path string) (*storedCredentials, error) { if path == "" { return nil, nil } data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, err } var creds storedCredentials if err := json.Unmarshal(data, &creds); err != nil { return nil, err } return &creds, nil } func saveStoredCredentials(path string, creds *storedCredentials) error { if path == "" { return nil } dir := filepath.Dir(path) if dir != "." { if err := os.MkdirAll(dir, 0o755); err != nil { return err } } data, err := json.MarshalIndent(creds, "", " ") if err != nil { return err } return os.WriteFile(path, data, 0o600) } func readDeviceIDFromCryptoDB(path string) (id.DeviceID, error) { if path == "" { return "", nil } if _, err := os.Stat(path); err != nil { if errors.Is(err, os.ErrNotExist) { return "", nil } return "", err } // Match mautrix's default DSN style. dsn := fmt.Sprintf("file:%s?_txlock=immediate", path) db, err := sql.Open("sqlite3", dsn) if err != nil { return "", err } defer db.Close() var deviceID id.DeviceID // CryptoHelper defaults DBAccountID to "", so account_id is "" unless explicitly set. err = db.QueryRow("SELECT device_id FROM crypto_account WHERE account_id = '' LIMIT 1").Scan(&deviceID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return "", nil } // Fresh/invalid DB. if strings.Contains(err.Error(), "no such table") { return "", nil } return "", err } return deviceID, nil } func loadConfig() { homeserver = envOrFatal("MATRIX_HOMESERVER") username = envOrFatal("MATRIX_USERNAME") password = envOrFatal("MATRIX_PASSWORD") roomID = envOrDefault("MATRIX_ROOM_ID", "") pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY") recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY") cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db") credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json") cryptoResetOnMismatch = strings.EqualFold(envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", ""), "true") || envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", "") == "1" } func isOlmAccountMismatch(err error) bool { if err == nil { return false } return strings.Contains(err.Error(), "olm account is not marked as shared") } func resetCryptoState() error { var resetErr error if cryptoDBPath != "" { if err := os.Remove(cryptoDBPath); err != nil && !errors.Is(err, os.ErrNotExist) { resetErr = err } } if credentialsPath != "" { if err := os.Remove(credentialsPath); err != nil && !errors.Is(err, os.ErrNotExist) { resetErr = err } } return resetErr } func main() { var err error loadConfig() bot.Load() cryptoDeviceID, err := readDeviceIDFromCryptoDB(cryptoDBPath) if err != nil { log.Fatal(err) } forceLogin := false stored, err := loadStoredCredentials(credentialsPath) if err != nil { log.Fatal(err) } var cachedUserID string var cachedAccessToken string var cachedDeviceID string if stored != nil { cachedUserID = stored.UserID cachedAccessToken = stored.AccessToken cachedDeviceID = stored.DeviceID if cachedUserID != "" || cachedAccessToken != "" || cachedDeviceID != "" { log.Println("Loaded credentials from", credentialsPath) } } client, err := mautrix.NewClient(homeserver, id.UserID(cachedUserID), cachedAccessToken) if err != nil { log.Fatal(err) } // Device ID source of truth: // 1. crypto DB (if present) // 2. cached credentials if cryptoDeviceID != "" { client.DeviceID = cryptoDeviceID if cachedDeviceID != "" && cachedDeviceID != cryptoDeviceID.String() { log.Printf("Device ID mismatch between credentials and crypto DB (%q != %q). Will re-login using crypto DB device ID.", cachedDeviceID, cryptoDeviceID) forceLogin = true } } else if cachedDeviceID != "" { client.DeviceID = id.DeviceID(cachedDeviceID) } if client.AccessToken == "" || client.UserID == "" || client.DeviceID == "" || forceLogin { log.Println("Logging in to Matrix") _, err = client.Login(context.Background(), &mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: username, }, Password: password, DeviceID: client.DeviceID, StoreCredentials: true, }) if err != nil { log.Fatal(err) } if err := saveStoredCredentials(credentialsPath, &storedCredentials{ UserID: client.UserID.String(), AccessToken: client.AccessToken, DeviceID: client.DeviceID.String(), }); err != nil { log.Fatal(err) } log.Println("Saved credentials to", credentialsPath) } auth := &authManager{ client: client, credentialsPath: credentialsPath, username: username, password: password, } syncer := mautrix.NewDefaultSyncer() client.Syncer = syncer cryptoHelper, err := setupCryptoHelper(client) if err != nil { if isOlmAccountMismatch(err) { log.Println("Detected olm account mismatch with server keys") if cryptoResetOnMismatch { if resetErr := resetCryptoState(); resetErr != nil { log.Fatal(resetErr) } log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.") } log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.") } log.Fatal(err) } client.Crypto = cryptoHelper auth.afterLogin = func(ctx context.Context) error { return verifyWithRecoveryKey(cryptoHelper.Machine()) } syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) { // Ignore our own messages if client.UserID != "" && evt.Sender == client.UserID { return } content := evt.Content.AsMessage() if content.MsgType != event.MsgText { return } log.Printf("Message from %s: %s\n", evt.Sender, content.Body) response := bot.HandleCommand(content.Body, evt.Sender.String(), evt.RoomID.String(), ctx, client, evt, &event.RelatesTo{ InReplyTo: &event.InReplyTo{EventID: evt.ID}, }) if response == nil { return } for _, resp := range response { switch r := resp.(type) { case event.MessageEventContent: _, err := client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, r) if err != nil { log.Println("Send error:", err) } default: log.Println("Unknown response type") } } }) ready := make(chan struct{}) var once sync.Once syncer.OnSync(func(ctx context.Context, resp *mautrix.RespSync, since string) bool { once.Do(func() { close(ready) }) return true }) go func() { runSyncWithAutoRelogin(context.Background(), client, auth) }() log.Println("Waiting for initial sync...") <-ready log.Println("Sync complete") if err := verifyWithRecoveryKey(cryptoHelper.Machine()); err != nil { if isOlmAccountMismatch(err) { log.Println("Detected olm account mismatch with server keys") if cryptoResetOnMismatch { if resetErr := resetCryptoState(); resetErr != nil { log.Fatal(resetErr) } log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.") } log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.") } log.Fatal(err) } log.Println("Bot is running...") select {} }