matrix-bot/main.go
shinya f35433af77
All checks were successful
Deploy Matrix Bot / deploy (push) Successful in 10s
leaderboard added
2026-05-16 09:58:43 +02:00

482 lines
12 KiB
Go

package main
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"matrix-bot/bot"
"os"
"path/filepath"
"strings"
"sync"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
_ "github.com/mattn/go-sqlite3"
)
// Process-wide configuration. Every value is assigned by loadConfig from the
// MATRIX_* environment variables and is read-only afterwards.
var (
	// Connection and login settings (MATRIX_HOMESERVER, MATRIX_USERNAME,
	// MATRIX_PASSWORD, MATRIX_ROOM_ID).
	homeserver string
	username   string
	password   string
	roomID     string

	// Encryption settings (MATRIX_PICKLE_KEY, MATRIX_RECOVERY_KEY,
	// MATRIX_CRYPTO_DB).
	pickleKeyString string
	recoveryKey     string
	cryptoDBPath    string

	// Session persistence and mismatch handling (MATRIX_CREDENTIALS_PATH,
	// MATRIX_CRYPTO_RESET_ON_MISMATCH).
	credentialsPath       string
	cryptoResetOnMismatch bool
)
// storedCredentials is the JSON document written to and read from the
// credentials file so that a login session can be picked up again at startup.
type storedCredentials struct {
	// UserID is the string form of the logged-in Matrix user ID.
	UserID string `json:"user_id"`

	// AccessToken is the access token held by the client after login.
	AccessToken string `json:"access_token"`

	// DeviceID is the string form of the client's device ID.
	DeviceID string `json:"device_id"`
}
// authManager bundles everything relogin needs to obtain a fresh access token
// for an existing client and persist it.
type authManager struct {
	// mu is held for the whole of relogin.
	mu sync.Mutex

	// client is the Matrix client whose session is refreshed.
	client *mautrix.Client

	// credentialsPath is where refreshed credentials are saved.
	credentialsPath string

	// username and password are the password-login inputs; relogin refuses
	// to run when either is empty.
	username, password string

	// afterLogin, when non-nil, is invoked by relogin after the new
	// credentials have been saved.
	afterLogin func(ctx context.Context) error
}
// isInvalidToken reports whether err indicates that the access token is no
// longer accepted: either one of the Matrix unknown-token / missing-token
// errors, or an HTTP error with status 401.
func isInvalidToken(err error) bool {
	var httpErr mautrix.HTTPError
	switch {
	case err == nil:
		return false
	case errors.Is(err, mautrix.MUnknownToken), errors.Is(err, mautrix.MMissingToken):
		return true
	case errors.As(err, &httpErr):
		// Some servers return 401 without a Matrix errcode.
		return httpErr.IsStatus(401)
	default:
		return false
	}
}
// relogin performs a password login on a.client, reusing the client's current
// device ID, and saves the resulting credentials to a.credentialsPath.
//
// It holds a.mu for its whole duration. It returns an error when the username
// or password is missing, when the login call fails, or when the credentials
// cannot be saved.
//
// If afterLogin is set it runs once the credentials are saved. An olm account
// mismatch reported by afterLogin terminates the process via log.Fatal (after
// resetting crypto state when cryptoResetOnMismatch is enabled); any other
// afterLogin error is only logged and relogin still returns nil.
func (a *authManager) relogin(ctx context.Context) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.username == "" || a.password == "" {
		return errors.New("missing MATRIX_USERNAME or MATRIX_PASSWORD for re-login")
	}

	// Ask the server for the same device so the existing crypto state stays usable.
	deviceID := a.client.DeviceID
	log.Printf("Re-logging in to Matrix (preferred device_id=%q)", deviceID)

	request := &mautrix.ReqLogin{
		Type: mautrix.AuthTypePassword,
		Identifier: mautrix.UserIdentifier{
			Type: mautrix.IdentifierTypeUser,
			User: a.username,
		},
		Password:         a.password,
		DeviceID:         deviceID,
		StoreCredentials: true,
	}
	if _, loginErr := a.client.Login(ctx, request); loginErr != nil {
		return loginErr
	}

	refreshed := &storedCredentials{
		UserID:      a.client.UserID.String(),
		AccessToken: a.client.AccessToken,
		DeviceID:    a.client.DeviceID.String(),
	}
	if saveErr := saveStoredCredentials(a.credentialsPath, refreshed); saveErr != nil {
		return saveErr
	}
	log.Println("Updated credentials in", a.credentialsPath)

	if a.afterLogin == nil {
		return nil
	}

	hookErr := a.afterLogin(ctx)
	switch {
	case hookErr == nil:
		// Nothing to report.
	case isOlmAccountMismatch(hookErr):
		log.Println("Detected olm account mismatch with server keys")
		if cryptoResetOnMismatch {
			if resetErr := resetCryptoState(); resetErr != nil {
				log.Fatal(resetErr)
			}
			log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
		}
		log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
	default:
		log.Println("Post-login recovery key verification failed:", hookErr)
	}
	return nil
}
// runSyncWithAutoRelogin runs the client's sync loop and restarts it after a
// successful re-login whenever the sync stops because the access token was
// rejected.
//
// It returns when the sync ends without an error or when ctx is cancelled.
// Failed re-logins are retried with an exponential back-off that starts at
// 2s and is capped at 30s; the wait is abandoned immediately if ctx is
// cancelled. Any sync error that is not a token error terminates the process
// via log.Fatal.
func runSyncWithAutoRelogin(ctx context.Context, client *mautrix.Client, auth *authManager) {
	const (
		initialBackoff = 2 * time.Second
		maxBackoff     = 30 * time.Second
	)
	backoff := initialBackoff

	for {
		err := client.SyncWithContext(ctx)
		if err == nil || ctx.Err() != nil {
			return
		}
		if !isInvalidToken(err) {
			log.Fatal(err)
		}

		log.Println("Matrix token invalid/expired; attempting re-login:", err)
		if err2 := auth.relogin(ctx); err2 != nil {
			log.Println("Re-login failed:", err2)

			// Wait out the back-off, but do not outlive a cancelled context.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}

			backoff *= 2
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			continue
		}

		// Fresh token obtained: start the next failure streak from the beginning.
		backoff = initialBackoff
	}
}
// setupCryptoHelper creates and initialises a CryptoHelper for cli, using
// pickleKeyString as the pickle key and cryptoDBPath as the database location.
// When cryptoDBPath names a file inside a directory other than ".", that
// directory is created first. Initialisation runs with a background context.
func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) {
	if cryptoDBPath != "" {
		if parent := filepath.Dir(cryptoDBPath); parent != "." {
			if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
				return nil, mkErr
			}
		}
	}

	helper, newErr := cryptohelper.NewCryptoHelper(cli, []byte(pickleKeyString), cryptoDBPath)
	if newErr != nil {
		return nil, newErr
	}
	if initErr := helper.Init(context.Background()); initErr != nil {
		return nil, initErr
	}
	return helper, nil
}
// verifyWithRecoveryKey uses the configured recoveryKey to cross-sign this
// device. In order it: fetches the default SSSS key data, verifies the
// recovery key against it, fetches the cross-signing keys from SSSS, signs
// the machine's own device, and signs the own master key. The first failing
// step's error is returned; all calls use a background context.
func verifyWithRecoveryKey(machine *crypto.OlmMachine) (err error) {
	ctx := context.Background()

	keyID, keyData, err := machine.SSSS.GetDefaultKeyData(ctx)
	if err != nil {
		return err
	}

	ssssKey, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
	if err != nil {
		return err
	}

	if err = machine.FetchCrossSigningKeysFromSSSS(ctx, ssssKey); err != nil {
		return err
	}
	if err = machine.SignOwnDevice(ctx, machine.OwnIdentity()); err != nil {
		return err
	}
	return machine.SignOwnMasterKey(ctx)
}
// envOrFatal returns the value of the environment variable key. An unset or
// empty variable terminates the process via log.Fatalf.
func envOrFatal(key string) string {
	if val := os.Getenv(key); val != "" {
		return val
	}
	log.Fatalf("missing required env var: %s", key)
	return "" // not reached: log.Fatalf exits the process
}
// envOrDefault returns the value of the environment variable key, or def when
// the variable is unset or empty.
func envOrDefault(key string, def string) string {
	if val := os.Getenv(key); val != "" {
		return val
	}
	return def
}
// loadStoredCredentials reads and decodes the JSON credentials file at path.
//
// It returns (nil, nil) — meaning "no stored credentials" — when path is
// empty, when the file does not exist, or when the file is zero bytes long
// (what an interrupted write leaves behind, since the file is truncated before
// the new content is written). Any other read failure or a JSON decoding
// failure is returned wrapped with the file path for context.
func loadStoredCredentials(path string) (*storedCredentials, error) {
	if path == "" {
		return nil, nil
	}

	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("reading credentials: %w", err)
	}
	if len(data) == 0 {
		// Treat an empty file like a missing one so startup falls back to a
		// fresh login instead of failing on a JSON syntax error.
		return nil, nil
	}

	var creds storedCredentials
	if err := json.Unmarshal(data, &creds); err != nil {
		return nil, fmt.Errorf("parsing credentials file %q: %w", path, err)
	}
	return &creds, nil
}
// saveStoredCredentials writes creds as indented JSON to path with file mode
// 0600, creating the parent directory first when it is not ".". An empty path
// disables persistence and returns nil without writing anything.
func saveStoredCredentials(path string, creds *storedCredentials) error {
	if path == "" {
		return nil
	}

	if parent := filepath.Dir(path); parent != "." {
		if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
			return mkErr
		}
	}

	encoded, marshalErr := json.MarshalIndent(creds, "", " ")
	if marshalErr != nil {
		return marshalErr
	}
	return os.WriteFile(path, encoded, 0o600)
}
// readDeviceIDFromCryptoDB looks up the device ID recorded in the SQLite
// crypto database at path.
//
// It returns an empty device ID and a nil error when path is empty, the file
// does not exist, the crypto_account table is missing, or the table has no row
// with an empty account_id. Any other stat, open or query failure is returned.
func readDeviceIDFromCryptoDB(path string) (id.DeviceID, error) {
	if path == "" {
		return "", nil
	}

	// Check for the file up front so the lookup is skipped when there is no DB.
	_, statErr := os.Stat(path)
	if errors.Is(statErr, os.ErrNotExist) {
		return "", nil
	}
	if statErr != nil {
		return "", statErr
	}

	// Match mautrix's default DSN style.
	db, openErr := sql.Open("sqlite3", fmt.Sprintf("file:%s?_txlock=immediate", path))
	if openErr != nil {
		return "", openErr
	}
	defer db.Close()

	// CryptoHelper defaults DBAccountID to "", so account_id is "" unless explicitly set.
	const query = "SELECT device_id FROM crypto_account WHERE account_id = '' LIMIT 1"

	var found id.DeviceID
	scanErr := db.QueryRow(query).Scan(&found)
	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", nil
	case strings.Contains(scanErr.Error(), "no such table"):
		// Fresh/invalid DB.
		return "", nil
	default:
		return "", scanErr
	}
}
// loadConfig fills the package-level configuration variables from the
// environment. MATRIX_HOMESERVER, MATRIX_USERNAME, MATRIX_PASSWORD,
// MATRIX_PICKLE_KEY and MATRIX_RECOVERY_KEY are mandatory (the process exits
// if one is missing); the remaining settings fall back to defaults.
// MATRIX_CRYPTO_RESET_ON_MISMATCH enables the reset when it is "1" or any
// casing of "true".
func loadConfig() {
	homeserver = envOrFatal("MATRIX_HOMESERVER")
	username = envOrFatal("MATRIX_USERNAME")
	password = envOrFatal("MATRIX_PASSWORD")
	roomID = envOrDefault("MATRIX_ROOM_ID", "")
	pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY")
	recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY")
	cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db")
	credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json")

	resetFlag := envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", "")
	cryptoResetOnMismatch = resetFlag == "1" || strings.EqualFold(resetFlag, "true")
}
// isOlmAccountMismatch reports whether err's message contains the text
// "olm account is not marked as shared". A nil error is never a mismatch.
func isOlmAccountMismatch(err error) bool {
	const marker = "olm account is not marked as shared"
	return err != nil && strings.Contains(err.Error(), marker)
}
// resetCryptoState deletes the crypto database file (cryptoDBPath) and the
// credentials file (credentialsPath). An empty path is skipped and a file that
// is already absent is not an error.
//
// Both removals are always attempted. Every failure is reported: the returned
// error joins all removal errors, and is nil when there were none.
//
// NOTE(review): only the two configured paths are removed; if the database
// keeps side files next to it (e.g. SQLite -wal/-shm), they are left in place —
// confirm whether that matters.
func resetCryptoState() error {
	var failures []error
	for _, path := range []string{cryptoDBPath, credentialsPath} {
		if path == "" {
			continue
		}
		if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
			failures = append(failures, err)
		}
	}
	// errors.Join returns nil when failures is empty.
	return errors.Join(failures...)
}
// main wires the bot together: it loads configuration, restores or creates a
// Matrix login session, sets up end-to-end encryption, registers the message
// handler, starts syncing in the background, verifies the device with the
// recovery key after the first sync, and then blocks forever.
// Any unrecoverable error along the way terminates the process via log.Fatal.
func main() {
var err error
loadConfig()
bot.Load()
// Read the device ID recorded in the crypto DB (empty if there is none yet).
cryptoDeviceID, err := readDeviceIDFromCryptoDB(cryptoDBPath)
if err != nil {
log.Fatal(err)
}
// forceLogin is set when stored credentials disagree with the crypto DB.
forceLogin := false
stored, err := loadStoredCredentials(credentialsPath)
if err != nil {
log.Fatal(err)
}
var cachedUserID string
var cachedAccessToken string
var cachedDeviceID string
if stored != nil {
cachedUserID = stored.UserID
cachedAccessToken = stored.AccessToken
cachedDeviceID = stored.DeviceID
if cachedUserID != "" || cachedAccessToken != "" || cachedDeviceID != "" {
log.Println("Loaded credentials from", credentialsPath)
}
}
// The client starts with whatever cached identity is available (possibly empty).
client, err := mautrix.NewClient(homeserver, id.UserID(cachedUserID), cachedAccessToken)
if err != nil {
log.Fatal(err)
}
// Device ID source of truth:
// 1. crypto DB (if present)
// 2. cached credentials
if cryptoDeviceID != "" {
client.DeviceID = cryptoDeviceID
if cachedDeviceID != "" && cachedDeviceID != cryptoDeviceID.String() {
log.Printf("Device ID mismatch between credentials and crypto DB (%q != %q). Will re-login using crypto DB device ID.", cachedDeviceID, cryptoDeviceID)
forceLogin = true
}
} else if cachedDeviceID != "" {
client.DeviceID = id.DeviceID(cachedDeviceID)
}
// Log in with the password when any part of the session is missing or a
// re-login was forced above; the chosen device ID is passed along.
if client.AccessToken == "" || client.UserID == "" || client.DeviceID == "" || forceLogin {
log.Println("Logging in to Matrix")
_, err = client.Login(context.Background(), &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: username,
},
Password: password,
DeviceID: client.DeviceID,
StoreCredentials: true,
})
if err != nil {
log.Fatal(err)
}
// Persist the new session so the next start can reuse it.
if err := saveStoredCredentials(credentialsPath, &storedCredentials{
UserID: client.UserID.String(),
AccessToken: client.AccessToken,
DeviceID: client.DeviceID.String(),
}); err != nil {
log.Fatal(err)
}
log.Println("Saved credentials to", credentialsPath)
}
// auth is used by the sync loop to re-login; its afterLogin hook is attached
// further down, once the crypto helper exists.
auth := &authManager{
client: client,
credentialsPath: credentialsPath,
username: username,
password: password,
}
syncer := mautrix.NewDefaultSyncer()
client.Syncer = syncer
cryptoHelper, err := setupCryptoHelper(client)
if err != nil {
// An olm account mismatch is fatal either way; with
// cryptoResetOnMismatch the local crypto state is wiped before exiting.
if isOlmAccountMismatch(err) {
log.Println("Detected olm account mismatch with server keys")
if cryptoResetOnMismatch {
if resetErr := resetCryptoState(); resetErr != nil {
log.Fatal(resetErr)
}
log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
}
log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
}
log.Fatal(err)
}
client.Crypto = cryptoHelper
// After every re-login, verify the device with the recovery key again.
auth.afterLogin = func(ctx context.Context) error {
return verifyWithRecoveryKey(cryptoHelper.Machine())
}
// Message handler: passes text messages from other users to the bot and
// sends back whatever message contents it returns.
// NOTE(review): roomID is loaded by loadConfig but not referenced in this
// function, so the handler is not restricted to a room here — confirm intent.
syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) {
// Ignore our own messages
if client.UserID != "" && evt.Sender == client.UserID {
return
}
content := evt.Content.AsMessage()
if content.MsgType != event.MsgText {
return
}
log.Printf("Message from %s: %s\n", evt.Sender, content.Body)
// The RelatesTo value handed to the bot marks a reply to the triggering event.
response := bot.HandleCommand(content.Body, evt.Sender.String(), evt.RoomID.String(),
ctx, client, evt, &event.RelatesTo{
InReplyTo: &event.InReplyTo{EventID: evt.ID},
})
if response == nil {
return
}
// Only event.MessageEventContent values are sent; anything else is logged.
for _, resp := range response {
switch r := resp.(type) {
case event.MessageEventContent:
_, err := client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, r)
if err != nil {
log.Println("Send error:", err)
}
default:
log.Println("Unknown response type")
}
}
})
// ready is closed the first time the OnSync callback fires; once guards
// against closing it again on later callbacks.
ready := make(chan struct{})
var once sync.Once
syncer.OnSync(func(ctx context.Context, resp *mautrix.RespSync, since string) bool {
once.Do(func() { close(ready) })
return true
})
// Sync runs in the background for the lifetime of the process.
go func() {
runSyncWithAutoRelogin(context.Background(), client, auth)
}()
log.Println("Waiting for initial sync...")
<-ready
log.Println("Sync complete")
// Initial verification with the recovery key; unlike the afterLogin hook
// path in relogin, any failure here is fatal.
if err := verifyWithRecoveryKey(cryptoHelper.Machine()); err != nil {
if isOlmAccountMismatch(err) {
log.Println("Detected olm account mismatch with server keys")
if cryptoResetOnMismatch {
if resetErr := resetCryptoState(); resetErr != nil {
log.Fatal(resetErr)
}
log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
}
log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
}
log.Fatal(err)
}
log.Println("Bot is running...")
// Block forever; the process ends only through log.Fatal or an external signal.
select {}
}