Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,20 @@ func Exchange(ctx context.Context, packet *Packet, addr string) (*Packet, error)

// Exchange sends the packet to the given server and waits for a response. ctx
// must be non-nil.
//
// If the request packet cannot be encoded, the returned error is a
// *MalformedRequestError; such an error indicates a problem with the request
// and will not be resolved by retrying against another server. Network-level
// failures are returned as their underlying error (which usually implements
// net.Error), and an expired or cancelled ctx is reported via ctx.Err().
func (c *Client) Exchange(ctx context.Context, packet *Packet, addr string) (*Packet, error) {
if ctx == nil {
panic("nil context")
}

wire, err := packet.Encode()
if err != nil {
return nil, err
return nil, &MalformedRequestError{Err: err}
}

connNet := c.Net
Expand Down
38 changes: 38 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package radius

import (
"context"
"errors"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -160,3 +161,40 @@ func TestClient_Exchange_nilContext(t *testing.T) {
//lint:ignore SA1012 This test is specifically checking for a nil context
Exchange(nil, req, "")
}

func TestClient_Exchange_malformedRequest(t *testing.T) {
// A packet with an unknown Code cannot be encoded.
req := New(CodeAccessRequest, []byte(`secret`))
req.Code = Code(255)

client := Client{}
resp, err := client.Exchange(context.Background(), req, "127.0.0.1:1")
if resp != nil {
t.Fatalf("got non-nil response (%v); expected nil", resp)
}
if err == nil {
t.Fatal("got nil error; expected one")
}

var malformed *MalformedRequestError
if !errors.As(err, &malformed) {
t.Fatalf("got error %T (%v); expecting *MalformedRequestError", err, err)
}
if malformed.Unwrap() == nil {
t.Fatal("got nil from Unwrap; expected the underlying encode error")
}

// A network error must not be reported as a MalformedRequestError, so that
// callers can tell encoding problems apart from server/connection failures.
netReq := New(CodeAccessRequest, []byte(`secret`))
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
// 203.0.113.0/24 (TEST-NET-3) is reserved and unroutable.
_, netErr := client.Exchange(ctx, netReq, "203.0.113.1:1")
if netErr == nil {
t.Fatal("got nil error for unreachable server; expected one")
}
if errors.As(netErr, &malformed) {
t.Fatalf("network error reported as *MalformedRequestError: %v", netErr)
}
}
32 changes: 32 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,35 @@ type NonAuthenticResponseError struct {
func (e *NonAuthenticResponseError) Error() string {
return `radius: non-authentic response`
}

// MalformedRequestError is returned by Client.Exchange when the request packet
// could not be encoded into its wire format. It indicates a problem with the
// request itself (for example, attributes that are too long or an unknown
// packet Code), so retrying the same request against another server will not
// help.
//
// Callers can use errors.As to distinguish this client-side error from network
// errors (which implement net.Error) when deciding whether to fail over to
// another server:
//
// resp, err := client.Exchange(ctx, packet, addr)
// var malformed *radius.MalformedRequestError
// switch {
// case err == nil:
// // success
// case errors.As(err, &malformed):
// // the request is invalid; do not retry against other servers
// default:
// // network or server error; failover to the next server may help
// }
type MalformedRequestError struct {
Err error
}

func (e *MalformedRequestError) Error() string {
return "radius: malformed request: " + e.Err.Error()
}

func (e *MalformedRequestError) Unwrap() error {
return e.Err
}