Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change status package to deal with concrete types instead of interfaces #1171

Merged
merged 5 commits into from
Apr 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion call.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(status.Status); ok {
if _, ok := status.FromError(err); ok {
return err
}
if err == errConnClosing || err == errConnUnavailable {
Expand Down
4 changes: 2 additions & 2 deletions call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func TestInvokeLargeErr(t *testing.T) {
var reply string
req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(status.Status); !ok {
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
Expand All @@ -256,7 +256,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
var reply string
req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(status.Status); !ok {
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
if got, want := ErrorDesc(err), weirdError; got != want {
Expand Down
5 changes: 3 additions & 2 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,10 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {

// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
switch e := err.(type) {
case status.Status:
if _, ok := status.FromError(err); ok {
return err
}
switch e := err.(type) {
case transport.StreamError:
return status.Error(e.Code, e.Desc)
case transport.ConnectionError:
Expand Down
2 changes: 1 addition & 1 deletion rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestToRPCErr(t *testing.T) {
{transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)},
} {
err := toRPCErr(test.errIn)
if _, ok := err.(status.Status); !ok {
if _, ok := status.FromError(err); !ok {
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, ""))
}
if !reflect.DeepEqual(err, test.errOut) {
Expand Down
39 changes: 21 additions & 18 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,25 +682,27 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
}
if err != nil {
switch st := err.(type) {
case status.Status:
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
} else {
switch st := err.(type) {
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
}
default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
}
return err
}

if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := err.(status.Status); ok {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
Expand Down Expand Up @@ -852,15 +854,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
appErr = s.opts.streamInt(server, ss, info, sd.Handler)
}
if appErr != nil {
switch err := appErr.(type) {
case status.Status:
// Do nothing
case transport.StreamError:
appErr = status.Error(err.Code, err.Desc)
default:
appErr = status.Error(convertCode(appErr), appErr.Error())
appStatus, ok := status.FromError(appErr)
if !ok {
switch err := appErr.(type) {
case transport.StreamError:
appStatus = status.New(err.Code, err.Desc)
default:
appStatus = status.New(convertCode(appErr), appErr.Error())
}
appErr = appStatus.Err()
}
appStatus, _ := status.FromError(appErr)
if trInfo != nil {
ss.mu.Lock()
ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
Expand Down
104 changes: 40 additions & 64 deletions status/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,78 +50,56 @@ import (
"google.golang.org/grpc/codes"
)

// Status provides access to grpc status details and is implemented by all
// errors returned from this package except nil errors, which are not typed.
// Note: gRPC users should not implement their own Statuses. Custom data may
// be attached to the spb.Status proto's Details field.
type Status interface {
// Code returns the status code.
Code() codes.Code
// Message returns the status message.
Message() string
// Proto returns a copy of the status in proto form.
Proto() *spb.Status
// Err returns an error representing the status.
Err() error
}

// okStatus is a Status whose Code method returns codes.OK, but does not
// implement error. To represent an OK code as an error, use an untyped nil.
type okStatus struct{}

func (okStatus) Code() codes.Code {
return codes.OK
}

func (okStatus) Message() string {
return ""
}

func (okStatus) Proto() *spb.Status {
return nil
}
// statusError is an alias of a status proto. It implements error and Status,
// and a nil statusError should never be returned by this package.
type statusError spb.Status

func (okStatus) Err() error {
return nil
func (se *statusError) Error() string {
p := (*spb.Status)(se)
return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(p.GetCode()), p.GetMessage())
}

// statusError contains a status proto. It is embedded and not aliased to
// allow for accessor functions of the same name. It implements error and
// Status, and a nil statusError should never be returned by this package.
type statusError struct {
*spb.Status
func (se *statusError) status() *Status {
return &Status{s: (*spb.Status)(se)}
}

func (se *statusError) Error() string {
return fmt.Sprintf("rpc error: code = %s desc = %s", se.Code(), se.Message())
// Status represents an RPC status code, message, and details. It is immutable
// and should be created with New, Newf, or FromProto.
type Status struct {
s *spb.Status
}

func (se *statusError) Code() codes.Code {
return codes.Code(se.Status.Code)
// Code returns the status code contained in s.
func (s *Status) Code() codes.Code {
return codes.Code(s.s.Code)
}

func (se *statusError) Message() string {
return se.Status.Message
// Message returns the message contained in s.
func (s *Status) Message() string {
return s.s.Message
}

func (se *statusError) Proto() *spb.Status {
return proto.Clone(se.Status).(*spb.Status)
// Proto returns s's status as an spb.Status proto message.
func (s *Status) Proto() *spb.Status {
return proto.Clone(s.s).(*spb.Status)
}

func (se *statusError) Err() error {
return se
// Err returns an immutable error representing s; returns nil if s.Code() is
// OK.
func (s *Status) Err() error {
if s.Code() == codes.OK {
return nil
}
return (*statusError)(s.s)
}

// New returns a Status representing c and msg.
func New(c codes.Code, msg string) Status {
if c == codes.OK {
return okStatus{}
}
return &statusError{Status: &spb.Status{Code: int32(c), Message: msg}}
func New(c codes.Code, msg string) *Status {
return &Status{s: &spb.Status{Code: int32(c), Message: msg}}
}

// Newf returns New(c, fmt.Sprintf(format, a...)).
func Newf(c codes.Code, format string, a ...interface{}) Status {
func Newf(c codes.Code, format string, a ...interface{}) *Status {
return New(c, fmt.Sprintf(format, a...))
}

Expand All @@ -140,21 +118,19 @@ func ErrorProto(s *spb.Status) error {
return FromProto(s).Err()
}

// FromProto returns a Status representing s. If s.Code is OK, Message and
// Details may be lost.
func FromProto(s *spb.Status) Status {
if s.GetCode() == int32(codes.OK) {
return okStatus{}
}
return &statusError{Status: proto.Clone(s).(*spb.Status)}
// FromProto returns a Status representing s.
func FromProto(s *spb.Status) *Status {
return &Status{s: proto.Clone(s).(*spb.Status)}
}

// FromError returns a Status representing err if it was produced from this
// package, otherwise it returns nil, false.
func FromError(err error) (s Status, ok bool) {
func FromError(err error) (s *Status, ok bool) {
if err == nil {
return okStatus{}, true
return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true
}
if s, ok := err.(*statusError); ok {
return s.status(), true
}
s, ok = err.(Status)
return s, ok
return nil, false
}
2 changes: 1 addition & 1 deletion status/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestError(t *testing.T) {
if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want {
t.Fatalf("err.Error() = %q; want %q", got, want)
}
s := err.(Status)
s, _ := FromError(err)
if got, want := s.Code(), codes.Internal; got != want {
t.Fatalf("err.Code() = %s; want %s", got, want)
}
Expand Down
2 changes: 1 addition & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(status.Status); ok {
if _, ok := status.FromError(err); ok {
return nil, err
}
if err == errConnClosing || err == errConnUnavailable {
Expand Down
2 changes: 1 addition & 1 deletion transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
}
}

func (ht *serverHandlerTransport) WriteStatus(s *Stream, st status.Status) error {
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
err := ht.do(func() {
ht.writeCommonHeaders(s)

Expand Down
2 changes: 1 addition & 1 deletion transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// There is no further I/O operations being able to perform on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st status.Status) error {
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
var headersSent, hasHeader bool
s.mu.Lock()
if s.state == streamDone {
Expand Down
4 changes: 2 additions & 2 deletions transport/http_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ type decodeState struct {
// statusGen caches the stream status received from the trailer the server
// sent. Client side only. Do not access directly. After all trailers are
// parsed, use the status method to retrieve the status.
statusGen status.Status
statusGen *status.Status
// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
// intended for direct access outside of parsing.
rawStatusCode int32
Expand Down Expand Up @@ -156,7 +156,7 @@ func validContentType(t string) bool {
return true
}

func (d *decodeState) status() status.Status {
func (d *decodeState) status() *status.Status {
if d.statusGen == nil {
// No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)
Expand Down
10 changes: 5 additions & 5 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ type Stream struct {
// multiple times.
headerDone bool
// the status error received from the server.
status status.Status
status *status.Status
// rstStream indicates whether a RST_STREAM frame needs to be sent
// to the server to signify that this stream is closing.
rstStream bool
Expand Down Expand Up @@ -285,7 +285,7 @@ func (s *Stream) Method() string {
}

// Status returns the status received from the server.
func (s *Stream) Status() status.Status {
func (s *Stream) Status() *status.Status {
return s.status
}

Expand Down Expand Up @@ -334,8 +334,8 @@ func (s *Stream) Read(p []byte) (n int, err error) {
}

// finish sets the stream's state and status, and closes the done channel.
// s.mu must be held by the caller.
func (s *Stream) finish(st status.Status) {
// s.mu must be held by the caller. st must always be non-nil.
func (s *Stream) finish(st *status.Status) {
s.status = st
s.state = streamDone
close(s.done)
Expand Down Expand Up @@ -508,7 +508,7 @@ type ServerTransport interface {

// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
WriteStatus(s *Stream, st status.Status) error
WriteStatus(s *Stream, st *status.Status) error

// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
Expand Down