blob: 11ac140ddfe7a57f0ed815224c297891eacc9082 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11
12 "github.com/mongodb/mongo-go-driver/bson"
13 "github.com/mongodb/mongo-go-driver/x/bsonx"
14 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
15 "github.com/mongodb/mongo-go-driver/x/network/description"
16 "github.com/mongodb/mongo-go-driver/x/network/result"
17 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
18)
19
20// must be sent to admin db
21// { endSessions: [ {id: uuid}, ... ], $clusterTime: ... }
22// only send $clusterTime when gossiping the cluster time
23// send 10k sessions at a time
24
25// EndSessions represents an endSessions command.
26type EndSessions struct {
27 Clock *session.ClusterClock
28 SessionIDs []bsonx.Doc
29
30 results []result.EndSessions
31 errors []error
32}
33
const (
	// BatchSize caps how many sessions are packed into a single endSessions command.
	BatchSize = 10000
)
36
37func (es *EndSessions) split() [][]bsonx.Doc {
38 batches := [][]bsonx.Doc{}
39 docIndex := 0
40 totalNumDocs := len(es.SessionIDs)
41
42createBatches:
43 for {
44 batch := []bsonx.Doc{}
45
46 for i := 0; i < BatchSize; i++ {
47 if docIndex == totalNumDocs {
48 break createBatches
49 }
50
51 batch = append(batch, es.SessionIDs[docIndex])
52 docIndex++
53 }
54
55 batches = append(batches, batch)
56 }
57
58 return batches
59}
60
61func (es *EndSessions) encodeBatch(batch []bsonx.Doc, desc description.SelectedServer) *Write {
62 vals := make(bsonx.Arr, 0, len(batch))
63 for _, doc := range batch {
64 vals = append(vals, bsonx.Document(doc))
65 }
66
67 cmd := bsonx.Doc{{"endSessions", bsonx.Array(vals)}}
68
69 return &Write{
70 Clock: es.Clock,
71 DB: "admin",
72 Command: cmd,
73 }
74}
75
76// Encode will encode this command into a series of wire messages for the given server description.
77func (es *EndSessions) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
78 cmds := es.encode(desc)
79 wms := make([]wiremessage.WireMessage, len(cmds))
80
81 for _, cmd := range cmds {
82 wm, err := cmd.Encode(desc)
83 if err != nil {
84 return nil, err
85 }
86
87 wms = append(wms, wm)
88 }
89
90 return wms, nil
91}
92
93func (es *EndSessions) encode(desc description.SelectedServer) []*Write {
94 out := []*Write{}
95 batches := es.split()
96
97 for _, batch := range batches {
98 out = append(out, es.encodeBatch(batch, desc))
99 }
100
101 return out
102}
103
104// Decode will decode the wire message using the provided server description. Errors during decoding
105// are deferred until either the Result or Err methods are called.
106func (es *EndSessions) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *EndSessions {
107 rdr, err := (&Write{}).Decode(desc, wm).Result()
108 if err != nil {
109 es.errors = append(es.errors, err)
110 return es
111 }
112
113 return es.decode(desc, rdr)
114}
115
116func (es *EndSessions) decode(desc description.SelectedServer, rdr bson.Raw) *EndSessions {
117 var res result.EndSessions
118 es.errors = append(es.errors, bson.Unmarshal(rdr, &res))
119 es.results = append(es.results, res)
120 return es
121}
122
123// Result returns the results of the decoded wire messages.
124func (es *EndSessions) Result() ([]result.EndSessions, []error) {
125 return es.results, es.errors
126}
127
128// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
129func (es *EndSessions) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) ([]result.EndSessions, []error) {
130 cmds := es.encode(desc)
131
132 for _, cmd := range cmds {
133 rdr, _ := cmd.RoundTrip(ctx, desc, rw) // ignore any errors returned by the command
134 es.decode(desc, rdr)
135 }
136
137 return es.Result()
138}