Skip to content

Commit

Permalink
[dataset] introduce WriteTlv() methods (openthread#9664)
Browse files Browse the repository at this point in the history
This commit adds different flavors of `WriteTlv()` including template
version as `Write<SimpleTlvType>()`. These methods write or update a
TLV in `Dataset`. The new methods replace the previous `SetTlv()`
methods. This new approach introduces type safety checks during
compilation, guaranteeing the use of the correct value type for each
TLV. For instance, `Write<PanIdTlv>()` only accepts `uint16_t` value,
while `Write<NetworkKeyTlv>()` only accepts `NetworkKey` value. This
commit also renames the previous `GetTlv()` to `FindTlv()`.
  • Loading branch information
abtink authored Dec 1, 2023
1 parent 5665de2 commit 685094b
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 180 deletions.
93 changes: 45 additions & 48 deletions src/core/meshcop/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ bool Dataset::IsValid(void) const
return rval;
}

const Tlv *Dataset::GetTlv(Tlv::Type aType) const { return As<Tlv>(Tlv::FindTlv(mTlvs, mLength, aType)); }
const Tlv *Dataset::FindTlv(Tlv::Type aType) const { return As<Tlv>(Tlv::FindTlv(mTlvs, mLength, aType)); }

void Dataset::ConvertTo(Info &aDatasetInfo) const
{
Expand Down Expand Up @@ -286,68 +286,68 @@ Error Dataset::SetFrom(const Info &aDatasetInfo)
Timestamp activeTimestamp;

aDatasetInfo.GetActiveTimestamp(activeTimestamp);
IgnoreError(SetTlv(Tlv::kActiveTimestamp, activeTimestamp));
IgnoreError(Write<ActiveTimestampTlv>(activeTimestamp));
}

if (aDatasetInfo.IsPendingTimestampPresent())
{
Timestamp pendingTimestamp;

aDatasetInfo.GetPendingTimestamp(pendingTimestamp);
IgnoreError(SetTlv(Tlv::kPendingTimestamp, pendingTimestamp));
IgnoreError(Write<PendingTimestampTlv>(pendingTimestamp));
}

if (aDatasetInfo.IsDelayPresent())
{
IgnoreError(SetTlv(Tlv::kDelayTimer, aDatasetInfo.GetDelay()));
IgnoreError(Write<DelayTimerTlv>(aDatasetInfo.GetDelay()));
}

if (aDatasetInfo.IsChannelPresent())
{
ChannelTlv tlv;
tlv.Init();
tlv.SetChannel(aDatasetInfo.GetChannel());
IgnoreError(SetTlv(tlv));
IgnoreError(WriteTlv(tlv));
}

if (aDatasetInfo.IsChannelMaskPresent())
{
ChannelMaskTlv tlv;
tlv.Init();
tlv.SetChannelMask(aDatasetInfo.GetChannelMask());
IgnoreError(SetTlv(tlv));
IgnoreError(WriteTlv(tlv));
}

if (aDatasetInfo.IsExtendedPanIdPresent())
{
IgnoreError(SetTlv(Tlv::kExtendedPanId, aDatasetInfo.GetExtendedPanId()));
IgnoreError(Write<ExtendedPanIdTlv>(aDatasetInfo.GetExtendedPanId()));
}

if (aDatasetInfo.IsMeshLocalPrefixPresent())
{
IgnoreError(SetTlv(Tlv::kMeshLocalPrefix, aDatasetInfo.GetMeshLocalPrefix()));
IgnoreError(Write<MeshLocalPrefixTlv>(aDatasetInfo.GetMeshLocalPrefix()));
}

if (aDatasetInfo.IsNetworkKeyPresent())
{
IgnoreError(SetTlv(Tlv::kNetworkKey, aDatasetInfo.GetNetworkKey()));
IgnoreError(Write<NetworkKeyTlv>(aDatasetInfo.GetNetworkKey()));
}

if (aDatasetInfo.IsNetworkNamePresent())
{
NameData nameData = aDatasetInfo.GetNetworkName().GetAsData();

IgnoreError(SetTlv(Tlv::kNetworkName, nameData.GetBuffer(), nameData.GetLength()));
IgnoreError(WriteTlv(Tlv::kNetworkName, nameData.GetBuffer(), nameData.GetLength()));
}

if (aDatasetInfo.IsPanIdPresent())
{
IgnoreError(SetTlv(Tlv::kPanId, aDatasetInfo.GetPanId()));
IgnoreError(Write<PanIdTlv>(aDatasetInfo.GetPanId()));
}

if (aDatasetInfo.IsPskcPresent())
{
IgnoreError(SetTlv(Tlv::kPskc, aDatasetInfo.GetPskc()));
IgnoreError(Write<PskcTlv>(aDatasetInfo.GetPskc()));
}

if (aDatasetInfo.IsSecurityPolicyPresent())
Expand All @@ -356,7 +356,7 @@ Error Dataset::SetFrom(const Info &aDatasetInfo)

tlv.Init();
tlv.SetSecurityPolicy(aDatasetInfo.GetSecurityPolicy());
IgnoreError(SetTlv(tlv));
IgnoreError(WriteTlv(tlv));
}

mUpdateTime = TimerMilli::GetNow();
Expand All @@ -371,13 +371,13 @@ Error Dataset::GetTimestamp(Type aType, Timestamp &aTimestamp) const

if (aType == kActive)
{
tlv = GetTlv(Tlv::kActiveTimestamp);
tlv = FindTlv(Tlv::kActiveTimestamp);
VerifyOrExit(tlv != nullptr, error = kErrorNotFound);
aTimestamp = tlv->ReadValueAs<ActiveTimestampTlv>();
}
else
{
tlv = GetTlv(Tlv::kPendingTimestamp);
tlv = FindTlv(Tlv::kPendingTimestamp);
VerifyOrExit(tlv != nullptr, error = kErrorNotFound);
aTimestamp = tlv->ReadValueAs<PendingTimestampTlv>();
}
Expand All @@ -388,43 +388,46 @@ Error Dataset::GetTimestamp(Type aType, Timestamp &aTimestamp) const

void Dataset::SetTimestamp(Type aType, const Timestamp &aTimestamp)
{
IgnoreError(SetTlv((aType == kActive) ? Tlv::kActiveTimestamp : Tlv::kPendingTimestamp, aTimestamp));
if (aType == kActive)
{
IgnoreError(Write<ActiveTimestampTlv>(aTimestamp));
}
else
{
IgnoreError(Write<PendingTimestampTlv>(aTimestamp));
}
}

Error Dataset::SetTlv(Tlv::Type aType, const void *aValue, uint8_t aLength)
Error Dataset::WriteTlv(Tlv::Type aType, const void *aValue, uint8_t aLength)
{
Error error = kErrorNone;
uint16_t bytesAvailable = sizeof(mTlvs) - mLength;
Tlv *old = GetTlv(aType);
Tlv tlv;
Tlv *oldTlv = FindTlv(aType);
Tlv *newTlv;

if (old != nullptr)
if (oldTlv != nullptr)
{
bytesAvailable += sizeof(Tlv) + old->GetLength();
bytesAvailable += sizeof(Tlv) + oldTlv->GetLength();
}

VerifyOrExit(sizeof(Tlv) + aLength <= bytesAvailable, error = kErrorNoBufs);

if (old != nullptr)
{
RemoveTlv(old);
}
RemoveTlv(oldTlv);

tlv.SetType(aType);
tlv.SetLength(aLength);
memcpy(mTlvs + mLength, &tlv, sizeof(Tlv));
mLength += sizeof(Tlv);
newTlv = GetTlvsEnd();
mLength += sizeof(Tlv) + aLength;

memcpy(mTlvs + mLength, aValue, aLength);
mLength += aLength;
newTlv->SetType(aType);
newTlv->SetLength(aLength);
memcpy(newTlv->GetValue(), aValue, aLength);

mUpdateTime = TimerMilli::GetNow();

exit:
return error;
}

Error Dataset::SetTlv(const Tlv &aTlv) { return SetTlv(aTlv.GetType(), aTlv.GetValue(), aTlv.GetLength()); }
Error Dataset::WriteTlv(const Tlv &aTlv) { return WriteTlv(aTlv.GetType(), aTlv.GetValue(), aTlv.GetLength()); }

Error Dataset::ReadFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
{
Expand All @@ -444,16 +447,7 @@ Error Dataset::ReadFromMessage(const Message &aMessage, uint16_t aOffset, uint16
return error;
}

void Dataset::RemoveTlv(Tlv::Type aType)
{
Tlv *tlv;

VerifyOrExit((tlv = GetTlv(aType)) != nullptr);
RemoveTlv(tlv);

exit:
return;
}
void Dataset::RemoveTlv(Tlv::Type aType) { RemoveTlv(FindTlv(aType)); }

Error Dataset::AppendMleDatasetTlv(Type aType, Message &aMessage) const
{
Expand Down Expand Up @@ -504,11 +498,14 @@ Error Dataset::AppendMleDatasetTlv(Type aType, Message &aMessage) const

void Dataset::RemoveTlv(Tlv *aTlv)
{
uint8_t *start = reinterpret_cast<uint8_t *>(aTlv);
uint16_t length = sizeof(Tlv) + aTlv->GetLength();
if (aTlv != nullptr)
{
uint8_t *start = reinterpret_cast<uint8_t *>(aTlv);
uint16_t length = sizeof(Tlv) + aTlv->GetLength();

memmove(start, start + length, mLength - (static_cast<uint8_t>(start - mTlvs) + length));
mLength -= length;
memmove(start, start + length, mLength - (static_cast<uint8_t>(start - mTlvs) + length));
mLength -= length;
}
}

Error Dataset::ApplyConfiguration(Instance &aInstance, bool *aIsNetworkKeyUpdated) const
Expand Down Expand Up @@ -609,7 +606,7 @@ void Dataset::SaveTlvInSecureStorageAndClearValue(Tlv::Type aTlvType, Crypto::St
{
using namespace ot::Crypto::Storage;

Tlv *tlv = GetTlv(aTlvType);
Tlv *tlv = FindTlv(aTlvType);

VerifyOrExit(tlv != nullptr);
VerifyOrExit(tlv->GetLength() > 0);
Expand All @@ -628,7 +625,7 @@ Error Dataset::ReadTlvFromSecureStorage(Tlv::Type aTlvType, Crypto::Storage::Key
using namespace ot::Crypto::Storage;

Error error = kErrorNone;
Tlv *tlv = GetTlv(aTlvType);
Tlv *tlv = FindTlv(aTlvType);
size_t readLength;

VerifyOrExit(tlv != nullptr);
Expand Down
Loading

0 comments on commit 685094b

Please sign in to comment.