/**
 * Adapts any {@link StateRestoreCallback} to a {@link RecordBatchingStateRestoreCallback}.
 *
 * <p>The adaptation depends on the runtime type of {@code restoreCallback}:
 * <ul>
 *   <li>If it already is a {@link RecordBatchingStateRestoreCallback}, it is returned unchanged.</li>
 *   <li>If it is a {@link BatchingStateRestoreCallback}, the returned callback reduces each
 *       {@link ConsumerRecord} in a batch to a {@link KeyValue} of its key and value. It then
 *       passes the whole list to {@code restoreAll} in a single call.</li>
 *   <li>Otherwise, the returned callback calls {@code restore(key, value)} once per record, in
 *       the order the batch is iterated.</li>
 * </ul>
 *
 * @param restoreCallback the callback to adapt; must not be {@code null}
 * @return a record-batching view of {@code restoreCallback}
 * @throws NullPointerException if {@code restoreCallback} is {@code null}
 */
public static RecordBatchingStateRestoreCallback adapt(final StateRestoreCallback restoreCallback) {
    Objects.requireNonNull(restoreCallback, "stateRestoreCallback must not be null");

    // Already consumes whole record batches: no wrapping needed.
    if (restoreCallback instanceof RecordBatchingStateRestoreCallback) {
        return (RecordBatchingStateRestoreCallback) restoreCallback;
    }

    if (!(restoreCallback instanceof BatchingStateRestoreCallback)) {
        // Plain callback: replay the batch one key/value pair at a time.
        return batch -> {
            for (final ConsumerRecord<byte[], byte[]> rec : batch) {
                restoreCallback.restore(rec.key(), rec.value());
            }
        };
    }

    // Key/value batching callback. The instanceof guard above makes this cast safe, so it is done
    // once here instead of on every batch. Each record is reduced to its key and value, and the
    // whole batch is delivered in a single restoreAll call.
    final BatchingStateRestoreCallback batchingCallback = (BatchingStateRestoreCallback) restoreCallback;
    return batch -> {
        final List<KeyValue<byte[], byte[]>> pairs = new ArrayList<>();
        for (final ConsumerRecord<byte[], byte[]> rec : batch) {
            pairs.add(new KeyValue<>(rec.key(), rec.value()));
        }
        batchingCallback.restoreAll(pairs);
    };
}
} // closes the enclosing type declaration