113 lines
2.5 KiB
Go
113 lines
2.5 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"synctv/server/internal/room"
|
|
)
|
|
|
|
type Rooms interface {
|
|
Get(ctx context.Context, code string) (room.Room, error)
|
|
}
|
|
|
|
type Proxy struct {
|
|
rooms Rooms
|
|
enabled bool
|
|
client *http.Client
|
|
}
|
|
|
|
func New(rooms Rooms, enabled bool, timeout time.Duration, maxIdleConns int, responseHeaderBytes int64) *Proxy {
|
|
return &Proxy{
|
|
rooms: rooms,
|
|
enabled: enabled,
|
|
client: &http.Client{
|
|
Timeout: timeout,
|
|
Transport: &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
MaxIdleConns: maxIdleConns,
|
|
ResponseHeaderTimeout: timeout,
|
|
MaxResponseHeaderBytes: responseHeaderBytes,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if !p.enabled {
|
|
http.Error(w, "proxy is disabled", http.StatusForbidden)
|
|
return
|
|
}
|
|
code := r.PathValue("code")
|
|
rm, err := p.rooms.Get(r.Context(), code)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusNotFound)
|
|
return
|
|
}
|
|
if rm.Source == nil || rm.Source.Mode != room.SourceModeProxy {
|
|
http.Error(w, "room source is not in proxy mode", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := p.stream(w, r, *rm.Source); err != nil {
|
|
if !errors.Is(err, context.Canceled) {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) stream(w http.ResponseWriter, r *http.Request, src room.Source) error {
|
|
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, src.URL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
copyRequestHeader(req.Header, r.Header, "Range", "User-Agent", "Accept", "Accept-Encoding")
|
|
for k, v := range src.Headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
if src.Username != "" || src.Password != "" {
|
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(src.Username+":"+src.Password)))
|
|
}
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
copyResponseHeaders(w.Header(), resp.Header)
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, err = io.Copy(w, resp.Body)
|
|
return err
|
|
}
|
|
|
|
// copyRequestHeader sets, for each name in keys, the first value of that header
// from src onto dst. Headers that are missing or contain only whitespace are
// skipped and leave dst untouched; copied ones replace any existing value.
func copyRequestHeader(dst, src http.Header, keys ...string) {
	for i := range keys {
		name := keys[i]
		first := src.Get(name)
		if strings.TrimSpace(first) == "" {
			continue
		}
		dst.Set(name, first)
	}
}
|
|
|
|
// copyResponseHeaders relays an allow-list of entity headers from an upstream
// response (src) to the client response headers (dst). Anything not listed,
// such as Set-Cookie or hop-by-hop headers, is intentionally dropped.
//
// Content-Encoding must be relayed: stream forwards the client's
// Accept-Encoding, and when a request carries that header explicitly,
// http.Transport returns the body still encoded. Without the header the client
// would receive compressed bytes it cannot interpret.
//
// For every listed header present upstream, the upstream values replace any
// value already in dst rather than being appended, so e.g. a pre-set
// Content-Length can never end up duplicated. Headers absent upstream leave
// dst untouched.
func copyResponseHeaders(dst, src http.Header) {
	for _, key := range []string{
		"Content-Type",
		"Content-Length",
		"Content-Encoding",
		"Content-Range",
		"Accept-Ranges",
		"Cache-Control",
		"Last-Modified",
		"ETag",
	} {
		values := src.Values(key)
		if len(values) == 0 {
			continue
		}
		dst.Del(key)
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}
|