330 lines
7.8 KiB
Go
330 lines
7.8 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"log"
|
|
"matrix-bot/bot"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/crypto/cryptohelper"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// Runtime configuration. Every variable below is assigned by loadConfig from a
// MATRIX_* environment variable; main may afterwards fill the session
// credentials from the stored credentials file or from a password login.
var (
	// Server address and password-login settings. roomID is loaded but not
	// read anywhere else in this file.
	homeserver string
	username   string
	password   string
	roomID     string

	// Session credentials used to construct the Matrix client.
	userId      string
	accessToken string
	deviceId    string

	// End-to-end encryption settings consumed by setupCryptoHelper and
	// verifyWithRecoveryKey.
	pickleKeyString string
	recoveryKey     string
	cryptoDBPath    string

	// Where session credentials are persisted, and whether local crypto state
	// is deleted automatically when an olm account mismatch is detected.
	credentialsPath       string
	cryptoResetOnMismatch bool
)
|
|
|
|
// storedCredentials is the JSON document that saveStoredCredentials writes and
// loadStoredCredentials reads back, so a session obtained by password login can
// be reused on a later start.
type storedCredentials struct {
	UserID      string `json:"user_id"`      // user ID reported by the client after login
	AccessToken string `json:"access_token"` // access token for the session; treat as a secret
	DeviceID    string `json:"device_id"`    // device ID reported by the client after login
}
|
|
|
|
func setupCryptoHelper(cli *mautrix.Client) (*cryptohelper.CryptoHelper, error) {
|
|
// remember to use a secure key for the pickle key in production
|
|
pickleKey := []byte(pickleKeyString)
|
|
|
|
if cryptoDBPath != "" {
|
|
dir := filepath.Dir(cryptoDBPath)
|
|
if dir != "." {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
helper, err := cryptohelper.NewCryptoHelper(cli, pickleKey, cryptoDBPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = helper.Init(context.Background())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return helper, nil
|
|
}
|
|
|
|
func verifyWithRecoveryKey(machine *crypto.OlmMachine) (err error) {
|
|
ctx := context.Background()
|
|
|
|
keyId, keyData, err := machine.SSSS.GetDefaultKeyData(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
key, err := keyData.VerifyRecoveryKey(keyId, recoveryKey)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = machine.FetchCrossSigningKeysFromSSSS(ctx, key)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = machine.SignOwnDevice(ctx, machine.OwnIdentity())
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = machine.SignOwnMasterKey(ctx)
|
|
|
|
return
|
|
}
|
|
|
|
// envOrFatal returns the value of the environment variable key. An unset or
// empty variable is a configuration error: the variable name is logged and the
// process exits through log.Fatalf.
func envOrFatal(key string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	log.Fatalf("missing required env var: %s", key)
	return "" // not reached: log.Fatalf terminates the process
}
|
|
|
|
// envOrDefault returns the value of the environment variable key, or def when
// the variable is unset or set to the empty string.
func envOrDefault(key string, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}
|
|
|
|
func loadStoredCredentials(path string) (*storedCredentials, error) {
|
|
if path == "" {
|
|
return nil, nil
|
|
}
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
var creds storedCredentials
|
|
if err := json.Unmarshal(data, &creds); err != nil {
|
|
return nil, err
|
|
}
|
|
return &creds, nil
|
|
}
|
|
|
|
func saveStoredCredentials(path string, creds *storedCredentials) error {
|
|
if path == "" {
|
|
return nil
|
|
}
|
|
dir := filepath.Dir(path)
|
|
if dir != "." {
|
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
data, err := json.MarshalIndent(creds, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(path, data, 0o600)
|
|
}
|
|
|
|
func loadConfig() {
|
|
homeserver = envOrFatal("MATRIX_HOMESERVER")
|
|
username = envOrDefault("MATRIX_USERNAME", "")
|
|
password = envOrDefault("MATRIX_PASSWORD", "")
|
|
roomID = envOrDefault("MATRIX_ROOM_ID", "")
|
|
userId = envOrDefault("MATRIX_USER_ID", "")
|
|
accessToken = envOrDefault("MATRIX_ACCESS_TOKEN", "")
|
|
deviceId = envOrDefault("MATRIX_DEVICE_ID", "")
|
|
pickleKeyString = envOrFatal("MATRIX_PICKLE_KEY")
|
|
recoveryKey = envOrFatal("MATRIX_RECOVERY_KEY")
|
|
cryptoDBPath = envOrDefault("MATRIX_CRYPTO_DB", "crypto.db")
|
|
credentialsPath = envOrDefault("MATRIX_CREDENTIALS_PATH", "/data/credentials.json")
|
|
cryptoResetOnMismatch = strings.EqualFold(envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", ""), "true") ||
|
|
envOrDefault("MATRIX_CRYPTO_RESET_ON_MISMATCH", "") == "1"
|
|
}
|
|
|
|
// isOlmAccountMismatch reports whether err's message contains the text
// "olm account is not marked as shared". A nil error is never a mismatch.
// The check is a case-sensitive substring match on err.Error().
func isOlmAccountMismatch(err error) bool {
	const marker = "olm account is not marked as shared"
	return err != nil && strings.Contains(err.Error(), marker)
}
|
|
|
|
func resetCryptoState() error {
|
|
var resetErr error
|
|
if cryptoDBPath != "" {
|
|
if err := os.Remove(cryptoDBPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
resetErr = err
|
|
}
|
|
}
|
|
if credentialsPath != "" {
|
|
if err := os.Remove(credentialsPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
resetErr = err
|
|
}
|
|
}
|
|
return resetErr
|
|
}
|
|
|
|
func main() {
|
|
var err error
|
|
loadConfig()
|
|
bot.Load()
|
|
|
|
stored, err := loadStoredCredentials(credentialsPath)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
if stored != nil {
|
|
if accessToken == "" {
|
|
accessToken = stored.AccessToken
|
|
}
|
|
if deviceId == "" {
|
|
deviceId = stored.DeviceID
|
|
}
|
|
if userId == "" {
|
|
userId = stored.UserID
|
|
}
|
|
if accessToken != "" || deviceId != "" || userId != "" {
|
|
log.Println("Loaded credentials from", credentialsPath)
|
|
}
|
|
}
|
|
|
|
client, err := mautrix.NewClient(homeserver, id.UserID(userId), accessToken)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
if accessToken == "" || deviceId == "" {
|
|
if username == "" || password == "" {
|
|
log.Fatal("missing MATRIX_USERNAME or MATRIX_PASSWORD for credential bootstrap")
|
|
}
|
|
log.Println("Logging in to Matrix to bootstrap credentials")
|
|
_, err = client.Login(context.Background(), &mautrix.ReqLogin{
|
|
Type: mautrix.AuthTypePassword,
|
|
Identifier: mautrix.UserIdentifier{
|
|
Type: mautrix.IdentifierTypeUser,
|
|
User: username,
|
|
},
|
|
Password: password,
|
|
StoreCredentials: true,
|
|
})
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
accessToken = client.AccessToken
|
|
userId = client.UserID.String()
|
|
deviceId = client.DeviceID.String()
|
|
if err := saveStoredCredentials(credentialsPath, &storedCredentials{
|
|
UserID: userId,
|
|
AccessToken: accessToken,
|
|
DeviceID: deviceId,
|
|
}); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
log.Println("Saved credentials to", credentialsPath)
|
|
}
|
|
|
|
if userId == "" {
|
|
log.Fatal("missing MATRIX_USER_ID and no stored credentials")
|
|
}
|
|
|
|
client.DeviceID = id.DeviceID(deviceId)
|
|
|
|
syncer := mautrix.NewDefaultSyncer()
|
|
client.Syncer = syncer
|
|
|
|
cryptoHelper, err := setupCryptoHelper(client)
|
|
if err != nil {
|
|
if isOlmAccountMismatch(err) {
|
|
log.Println("Detected olm account mismatch with server keys")
|
|
if cryptoResetOnMismatch {
|
|
if resetErr := resetCryptoState(); resetErr != nil {
|
|
log.Fatal(resetErr)
|
|
}
|
|
log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
|
|
}
|
|
log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
|
|
}
|
|
log.Fatal(err)
|
|
}
|
|
client.Crypto = cryptoHelper
|
|
|
|
syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) {
|
|
// Ignore our own messages
|
|
if evt.Sender.String() == userId {
|
|
return
|
|
}
|
|
|
|
content := evt.Content.AsMessage()
|
|
if content.MsgType != event.MsgText {
|
|
return
|
|
}
|
|
|
|
log.Printf("Message from %s: %s\n", evt.Sender, content.Body)
|
|
|
|
response := bot.HandleCommand(content.Body, evt.Sender.String(), evt.RoomID.String(),
|
|
ctx, client, evt, &event.RelatesTo{
|
|
InReplyTo: &event.InReplyTo{EventID: evt.ID},
|
|
})
|
|
if response == nil {
|
|
return
|
|
}
|
|
|
|
for _, resp := range response {
|
|
switch r := resp.(type) {
|
|
case event.MessageEventContent:
|
|
_, err := client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, r)
|
|
if err != nil {
|
|
log.Println("Send error:", err)
|
|
}
|
|
default:
|
|
log.Println("Unknown response type")
|
|
}
|
|
}
|
|
})
|
|
ready := make(chan struct{})
|
|
var once sync.Once
|
|
syncer.OnSync(func(ctx context.Context, resp *mautrix.RespSync, since string) bool {
|
|
once.Do(func() { close(ready) })
|
|
return true
|
|
})
|
|
|
|
go func() {
|
|
if err := client.Sync(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
log.Println("Waiting for initial sync...")
|
|
<-ready
|
|
log.Println("Sync complete")
|
|
|
|
if err := verifyWithRecoveryKey(cryptoHelper.Machine()); err != nil {
|
|
if isOlmAccountMismatch(err) {
|
|
log.Println("Detected olm account mismatch with server keys")
|
|
if cryptoResetOnMismatch {
|
|
if resetErr := resetCryptoState(); resetErr != nil {
|
|
log.Fatal(resetErr)
|
|
}
|
|
log.Fatal("Reset crypto state due to mismatch. Restart the container to re-login.")
|
|
}
|
|
log.Fatal("Crypto mismatch. Remove crypto DB and credentials, then restart.")
|
|
}
|
|
log.Fatal(err)
|
|
}
|
|
|
|
log.Println("Bot is running...")
|
|
select {}
|
|
}
|