Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions drivers/123_open/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Open123 struct {
model.Storage
Addition
UID uint64
tm *tokenManager
}

func (d *Open123) Config() driver.Config {
Expand All @@ -33,6 +34,24 @@ func (d *Open123) Init(ctx context.Context) error {
d.UploadThread = 3
}

if d.RefreshToken != "" {
// refresh token 直接主动刷新
d.AccessToken = ""
d.tm = &tokenManager{}
} else {
// 避免个人 token 刷新产生的多个登录,被动刷新
// 默认过期时间90天,jwt exp 不可靠
d.tm = &tokenManager{
// accessToken: d.AccessToken,
expiredAt: time.Now().Add(90 * 24 * time.Hour),
}
}

_, err := d.getAccessToken(false)
if err != nil {
return fmt.Errorf("init get access token error: %w", err)
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/123_open/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Addition struct {
ClientID string `json:"ClientID" required:"false"`
ClientSecret string `json:"ClientSecret" required:"false"`

// 直接写入AccessToken
// 直接写入AccessToken, AccessToken有过期时间,不建议直接填写
AccessToken string `json:"AccessToken" required:"false"`

// 用户名+密码方式登录的AccessToken可以兼容
Expand Down
115 changes: 115 additions & 0 deletions drivers/123_open/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package _123_open

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"

"github.com/OpenListTeam/OpenList/v4/drivers/base"
"github.com/OpenListTeam/OpenList/v4/internal/op"
)

var (
AccessToken = "https://open-api.123pan.com/api/v1/access_token"
RefreshToken = "https://open-api.123pan.com/api/v1/oauth2/access_token"
)

type tokenManager struct {
// accessToken string
expiredAt time.Time
mu sync.Mutex
blockRefresh bool
}

func (d *Open123) getAccessToken(forceRefresh bool) (string, error) {
tm := d.tm
tm.mu.Lock()
defer tm.mu.Unlock()
if tm.blockRefresh {
return "", errors.New("Authentication expired")
}
if !forceRefresh && d.AccessToken != "" && time.Now().Before(tm.expiredAt.Add(-5*time.Minute)) {
return d.AccessToken, nil
}
if err := d.flushAccessToken(); err != nil {
// token expired and failed to refresh, block further refresh attempts
tm.blockRefresh = true
return "", err
}
return d.AccessToken, nil
}

func (d *Open123) flushAccessToken() error {
// directly send request to avoid deadlock
req := base.RestyClient.R()
req.SetHeaders(map[string]string{
"authorization": "Bearer " + d.AccessToken,
"platform": "open_platform",
"Content-Type": "application/json",
})

if d.ClientID != "" {
if d.RefreshToken != "" {
var resp RefreshTokenResp
req.SetQueryParam("client_id", d.ClientID)
if d.ClientSecret != "" {
req.SetQueryParam("client_secret", d.ClientSecret)
}
req.SetQueryParam("grant_type", "refresh_token")
req.SetQueryParam("refresh_token", d.RefreshToken)
req.SetResult(&resp)
res, err := req.Execute(http.MethodPost, RefreshToken)
if err != nil {
return err
}
body := res.Body()
var baseResp BaseResp
if err = json.Unmarshal(body, &baseResp); err != nil {
return err
}
if baseResp.Code != 0 {
return fmt.Errorf("get access token failed: %s", baseResp.Message)
}

d.AccessToken = resp.AccessToken
// add token expire time
d.tm.expiredAt = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
d.RefreshToken = resp.RefreshToken
op.MustSaveDriverStorage(d)
d.tm.blockRefresh = false
return nil
} else if d.ClientSecret != "" {
var resp AccessTokenResp
req.SetBody(base.Json{
"clientID": d.ClientID,
"clientSecret": d.ClientSecret,
})
req.SetResult(&resp)
res, err := req.Execute(http.MethodPost, AccessToken)
if err != nil {
return err
}
body := res.Body()
var baseResp BaseResp
if err = json.Unmarshal(body, &baseResp); err != nil {
return err
}
if baseResp.Code != 0 {
return fmt.Errorf("get access token failed: %s", baseResp.Message)
}
d.AccessToken = resp.Data.AccessToken
// parse token expire time
d.tm.expiredAt, err = time.Parse(time.RFC3339, resp.Data.ExpiredAt)
if err != nil {
return fmt.Errorf("parse expire time failed: %w", err)
}
op.MustSaveDriverStorage(d)
d.tm.blockRefresh = false
return nil
}
}
return errors.New("no valid authentication method available")
}
6 changes: 5 additions & 1 deletion drivers/123_open/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,18 @@ func (d *Open123) Upload(ctx context.Context, file model.FileStreamer, createRes
head := bytes.NewReader(b.Bytes()[:headSize])
tail := bytes.NewReader(b.Bytes()[headSize:])
rateLimitedRd = driver.NewLimitedUploadStream(ctx, io.MultiReader(head, reader, tail))
token, err := d.getAccessToken(false)
if err != nil {
return err
}
// 创建请求并设置header
req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadDomain+"/upload/v2/file/slice", rateLimitedRd)
if err != nil {
return err
}

// 设置请求头
req.Header.Add("Authorization", "Bearer "+d.AccessToken)
req.Header.Add("Authorization", "Bearer "+token)
req.Header.Add("Content-Type", w.FormDataContentType())
req.Header.Add("Platform", "open_platform")

Expand Down
52 changes: 8 additions & 44 deletions drivers/123_open/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"time"

"github.com/OpenListTeam/OpenList/v4/drivers/base"
"github.com/OpenListTeam/OpenList/v4/internal/op"
"github.com/go-resty/resty/v2"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
Expand All @@ -22,8 +21,6 @@ import (
var ( // 不同情况下获取的AccessTokenQPS限制不同 如下模块化易于拓展
Api = "https://open-api.123pan.com"

AccessToken = InitApiInfo(Api+"/api/v1/access_token", 1)
RefreshToken = InitApiInfo(Api+"/api/v1/oauth2/access_token", 1)
UserInfo = InitApiInfo(Api+"/api/v1/user/info", 1)
FileList = InitApiInfo(Api+"/api/v2/file/list", 3)
DownloadInfo = InitApiInfo(Api+"/api/v1/file/download_info", 5)
Expand All @@ -40,11 +37,14 @@ var ( // 不同情况下获取的AccessTokenQPS限制不同 如下模块化易
)

func (d *Open123) Request(apiInfo *ApiInfo, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
retryToken := true
for {
token, err := d.getAccessToken(false)
if err != nil {
return nil, err
}
req := base.RestyClient.R()
req.SetHeaders(map[string]string{
"authorization": "Bearer " + d.AccessToken,
"authorization": "Bearer " + token,
"platform": "open_platform",
"Content-Type": "application/json",
})
Expand Down Expand Up @@ -74,9 +74,9 @@ func (d *Open123) Request(apiInfo *ApiInfo, method string, callback base.ReqCall

if baseResp.Code == 0 {
return body, nil
} else if baseResp.Code == 401 && retryToken {
retryToken = false
if err := d.flushAccessToken(); err != nil {
} else if baseResp.Code == 401 {
// 强制刷新Token, 有小概率会 race condition 导致多次刷新Token,但不影响正确运行
if _, err := d.getAccessToken(true); err != nil {
return nil, err
}
} else if baseResp.Code == 429 {
Expand All @@ -88,42 +88,6 @@ func (d *Open123) Request(apiInfo *ApiInfo, method string, callback base.ReqCall
}
}

func (d *Open123) flushAccessToken() error {
if d.ClientID != "" {
if d.RefreshToken != "" {
var resp RefreshTokenResp
_, err := d.Request(RefreshToken, http.MethodPost, func(req *resty.Request) {
req.SetQueryParam("client_id", d.ClientID)
if d.ClientSecret != "" {
req.SetQueryParam("client_secret", d.ClientSecret)
}
req.SetQueryParam("grant_type", "refresh_token")
req.SetQueryParam("refresh_token", d.RefreshToken)
}, &resp)
if err != nil {
return err
}
d.AccessToken = resp.AccessToken
d.RefreshToken = resp.RefreshToken
op.MustSaveDriverStorage(d)
} else if d.ClientSecret != "" {
var resp AccessTokenResp
_, err := d.Request(AccessToken, http.MethodPost, func(req *resty.Request) {
req.SetBody(base.Json{
"clientID": d.ClientID,
"clientSecret": d.ClientSecret,
})
}, &resp)
if err != nil {
return err
}
d.AccessToken = resp.Data.AccessToken
op.MustSaveDriverStorage(d)
}
}
return nil
}

func (d *Open123) SignURL(originURL, privateKey string, uid uint64, validDuration time.Duration) (newURL string, err error) {
// 生成Unix时间戳
ts := time.Now().Add(validDuration).Unix()
Expand Down